Note that all users of VitalDB, an open biosignal dataset, must agree to the Data Use Agreement below. If you do not agree, please close this window. The Data Use Agreement is available here: https://vitaldb.net/dataset/#h.vcpgs1yemdb5
For the Project Draft submission see the DL4H_Team_24_Project_Draft.ipynb notebook in the project repository.
The project repository can be found at: https://github.com/abarrie2/cs598-dlh-project
This project aims to reproduce findings from the paper titled "Predicting intraoperative hypotension using deep learning with waveforms of arterial blood pressure, electroencephalogram, and electrocardiogram: Retrospective study" by Jo Y-Y et al. (2022) [1]. This study introduces a deep learning model that predicts intraoperative hypotension (IOH) events before they occur, utilizing a combination of arterial blood pressure (ABP), electroencephalogram (EEG), and electrocardiogram (ECG) signals.
Intraoperative hypotension (IOH) is a common and significant surgical complication defined by a mean arterial pressure drop below 65 mmHg. It is associated with increased risks of myocardial infarction, acute kidney injury, and heightened postoperative mortality. Effective prediction and timely intervention can substantially enhance patient outcomes.
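To make the label definition concrete, the sketch below (our illustration, not the paper's code) flags an IOH event when a mean arterial pressure (MAP) series stays below 65 mmHg for at least one minute; the 0.5 Hz sampling rate and one-minute minimum duration are assumptions for the example.

```python
import numpy as np

def label_ioh(map_mmhg, srate_hz=0.5, min_duration_s=60, threshold=65.0):
    """Return True if the MAP series contains a stretch of at least
    `min_duration_s` seconds below `threshold` mmHg."""
    below = np.asarray(map_mmhg) < threshold
    needed = int(min_duration_s * srate_hz)  # samples required below threshold
    run = 0
    for flag in below:
        run = run + 1 if flag else 0
        if run >= needed:
            return True
    return False

# 2-second MAP samples (0.5 Hz): 70 consecutive samples below 65 mmHg = 140 s
map_series = [75.0] * 30 + [60.0] * 70 + [80.0] * 30
print(label_ioh(map_series))  # True
```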
Initial attempts to predict IOH primarily used arterial blood pressure (ABP) waveforms. A foundational study by Hatib F et al. (2018) titled "Machine-learning Algorithm to Predict Hypotension Based on High-fidelity Arterial Pressure Waveform Analysis" [2] showed that machine learning could forecast IOH events using ABP with reasonable accuracy. This finding spurred further research into utilizing various physiological signals for IOH prediction.
Subsequent advancements included the development of the Acumen™ hypotension prediction index, which was studied in "AcumenTM hypotension prediction index guidance for prevention and treatment of hypotension in noncardiac surgery: a prospective, single-arm, multicenter trial" by Bao X et al. (2024) [3]. This trial integrated a hypotension prediction index into blood pressure monitoring equipment, demonstrating its effectiveness in reducing the number and duration of IOH events during surgeries. Further study is needed to determine whether this reduction in IOH events translates into improved postoperative patient outcomes.
Building on these advancements, the paper by Jo Y-Y et al. (2022) proposes a deep learning approach that enhances prediction accuracy by incorporating EEG and ECG signals along with ABP. This multi-modal method, evaluated over prediction windows of 3, 5, 10, and 15 minutes, aims to provide a comprehensive physiological profile that could predict IOH more accurately and earlier. Their results indicate that the combination of ABP and EEG significantly improves performance metrics such as AUROC and AUPRC, outperforming models that use fewer signals or different combinations.
Our project seeks to reproduce and verify Jo Y-Y et al.'s results to assess whether this integrated approach can indeed improve IOH prediction accuracy, thereby potentially enhancing surgical safety and patient outcomes.
The original paper investigated the following hypotheses:
Results were compared using AUROC and AUPRC scores. Based on the results described in the original paper, we expect that Hypothesis 2 will be confirmed, and that Hypotheses 1 and 3 will not be confirmed.
In order to perform the corresponding experiments, we will implement a CNN-based model that can be configured to train and infer using the following four model variations:
We will measure the performance of these configurations using the same AUROC and AUPRC metrics as used in the original paper. To test hypothesis 1 we will compare the AUROC and AUPRC measures between model variation 1 and model variation 2. To test hypothesis 2 we will compare the AUROC and AUPRC measures between model variation 1 and model variation 3. To test hypothesis 3 we will compare the AUROC and AUPRC measures between model variation 1 and model variation 4. For all of the above measures and experiment combinations, we will operate multiple experiments where the time-to-IOH event prediction will use the following prediction windows:
In the event that we are compute-bound, we will prioritize the 3-minute prediction window experiments as they are the most relevant to the original paper's findings.
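Both comparison metrics can be computed directly with scikit-learn; a minimal sketch on synthetic labels and scores (not the project's actual predictions) follows, where AUPRC is computed as average precision.

```python
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score

rng = np.random.default_rng(42)
y_true = rng.integers(0, 2, size=1000)  # synthetic binary IOH/non-IOH labels
# synthetic scores correlated with the labels plus noise, clipped to [0, 1]
y_score = np.clip(y_true * 0.6 + rng.normal(0.2, 0.3, size=1000), 0, 1)

auroc = roc_auc_score(y_true, y_score)
auprc = average_precision_score(y_true, y_score)
print(f"AUROC={auroc:.3f}  AUPRC={auprc:.3f}")
```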
The predictive power of ABP, ECG and ABP + ECG models at 3-, 5-, 10- and 15-minute prediction windows:
In order to demonstrate the functioning of the code within a short (i.e., under the 8-minute limit) run, the following options and modifications were used:

- MAX_CASES was set to 20. The total number of cases to be used in the full training set is 3296, but the smaller number allows demonstration of each section of the pipeline.
- vitaldb_cache is prepopulated in Google Colab. The cache file is approx. 800MB, contains the raw and mini-fied copies of the source dataset, and is downloaded from Google Drive. This is much faster than using the vitaldb API, but again covers only a fraction of the data. The full dataset can be downloaded with the API or prepopulated by following the instructions in the "Bulk Data Download" section below.
- max_epochs is set to 6. With the small dataset, training is fast and shows the decreasing training and validation losses. In the full model run, max_epochs will be set to 100. In both cases early stopping is enabled and will stop training if the validation loss stops decreasing for five consecutive epochs.

The methodology section is composed of the following subsections: Environment, Data and Model.
The environment setup differs based on whether you are running the code on a local machine or on Google Colab. The following sections provide instructions for setting up the environment in each case.
Create conda environment for the project using the environment.yml file:
conda env create --prefix .envs/dlh-team24 -f environment.yml
Activate the environment with:
conda activate .envs/dlh-team24
The following code snippet installs the required packages and downloads the necessary files in a Google Colab environment:
# Google Colab environments have a `/content` directory. Use this as a proxy for running Colab-only code
COLAB_ENV = "google.colab" in str(get_ipython())
if COLAB_ENV:
#install vitaldb
%pip install vitaldb
# Executing in Colab therefore download cached preprocessed data.
# TODO: Integrate this with the setup local cache data section below.
# Check for file existence before overwriting.
import gdown
gdown.download(id="15b5Nfhgj3McSO2GmkVUKkhSSxQXX14hJ", output="vitaldb_cache.tgz")
!tar -zxf vitaldb_cache.tgz
# Download sqi_filter.csv from github repo
!wget https://raw.githubusercontent.com/abarrie2/cs598-dlh-project/main/sqi_filter.csv
All other required packages are already installed in the Google Colab environment.
# Import packages
import os
import random
import copy
from collections import defaultdict
from timeit import default_timer as timer
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, roc_auc_score, precision_recall_curve, auc, confusion_matrix
from sklearn.metrics import RocCurveDisplay, PrecisionRecallDisplay, average_precision_score
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
import torch
from torch.utils.data import Dataset
import vitaldb
import h5py
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from datetime import datetime
Set random seeds to generate consistent results:
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
torch.cuda.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed_all(RANDOM_SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
os.environ["PYTHONHASHSEED"] = str(RANDOM_SEED)
Data for this project is sourced from the open biosignal VitalDB dataset as described in "VitalDB, a high-fidelity multi-parameter vital signs database in surgical patients" by Lee H-C et al. (2022) [4], which contains perioperative vital signs and numerical data from 6,388 cases of non-cardiac (general, thoracic, urological, and gynecological) surgery patients who underwent routine or emergency surgery at Seoul National University Hospital between 2016 and 2017. The dataset includes ABP, ECG, and EEG signals, as well as other physiological data. The dataset is available through an API and Python library, and at PhysioNet: https://physionet.org/content/vitaldb/1.0.0/
Characteristics of the dataset:

| Characteristic | Value | Details |
|-----------------------|-----------------------------|------------------------|
| Total number of cases | 6,388 | |
| Sex (male) | 3,243 (50.8%) | |
| Age (years) | 59 | Range: 48-68 |
| Height (cm) | 162 | Range: 156-169 |
| Weight (kg) | 61 | Range: 53-69 |
| Tram-Rac 4A tracks | 6,355 (99.5%) | Sampling rate: 500Hz |
| BIS Vista tracks | 5,566 (87.1%) | Sampling rate: 128Hz |
| Case duration (min) | 189 | Range: 27-1041 |
Labels are only known after processing the data. In the original paper, there were an average of 1.6 IOH events and 5.7 non-events per case, so we expect approximately 10,221 IOH events and 36,412 non-events in the dataset.
Data will be processed as follows:
VitalDB data is static, so local copies can be stored and reused to avoid expensive downloads and to speed up data processing.
The default directory defined below is in the project .gitignore file. If this is modified, the new directory should also be added to the project .gitignore.
VITALDB_CACHE = './vitaldb_cache'
VITAL_ALL = f"{VITALDB_CACHE}/vital_all"
VITAL_MINI = f"{VITALDB_CACHE}/vital_mini"
VITAL_METADATA = f"{VITALDB_CACHE}/metadata"
VITAL_MODELS = f"{VITALDB_CACHE}/models"
VITAL_PREPROCESS_SCRATCH = f"{VITALDB_CACHE}/data_scratch"
VITAL_EXTRACTED_SEGMENTS = f"{VITALDB_CACHE}/segments"
TRACK_CACHE = None
SEGMENT_CACHE = None
# when USE_MEMORY_CACHING is enabled, track data will be persisted in an in-memory cache. Not useful once we have already pre-extracted all event segments
# DON'T USE: Stores items in memory that are later not used. Causes OOM on segment extraction.
USE_MEMORY_CACHING = False
# When RESET_CACHE is set to True, it will ensure the TRACK_CACHE is disposed and recreated when we do dataset initialization.
# Use as a shortcut to wiping cache rather than restarting kernel
RESET_CACHE = False
#PREDICTION_WINDOW = 5
PREDICTION_WINDOW = 'ALL'
ALL_PREDICTION_WINDOWS = [3, 5, 10, 15]
# Maximum number of cases of interest for which to download data.
# Set to a small value (e.g., 20) for demo purposes, or to None to download and process all cases.
MAX_CASES = None
#MAX_CASES = 200
# Preloading Cases: when true, all matched cases will have the _mini tracks extracted and put into in-mem dict
PRELOADING_CASES = False
PRELOADING_SEGMENTS = True
# Perform Data Preprocessing: do we want to take the raw vital file and extract segments of interest for training?
PERFORM_DATA_PREPROCESSING = False
for cache_dir in [VITALDB_CACHE, VITAL_ALL, VITAL_MINI, VITAL_METADATA,
                  VITAL_MODELS, VITAL_PREPROCESS_SCRATCH, VITAL_EXTRACTED_SEGMENTS]:
    os.makedirs(cache_dir, exist_ok=True)
print(os.listdir(VITALDB_CACHE))
['segments_bak', '.DS_Store', 'vital_all', 'models', 'docs', 'vital_mini.tar', 'data_scratch', 'osfs', 'vital_mini', 'metadata', 'segments_bak_0428_00', 'segments', 'models_old']
This step is not required, but will significantly speed up downstream processing and avoid a high volume of API requests to the VitalDB web site.
The cache population code checks whether the .vital files are locally available; the cache can be populated by calling the vitaldb API or by manually prepopulating it (recommended).
Run `wget -r -N -c -np https://physionet.org/files/vitaldb/1.0.0/` in a terminal to download the files, then move the downloaded .vital files into the ${VITAL_ALL} directory.
# Returns the Pandas DataFrame for the specified dataset.
# One of 'cases', 'labs', or 'trks'
# If the file exists locally, create and return the DataFrame.
# Else, download and cache the csv first, then return the DataFrame.
def vitaldb_dataframe_loader(dataset_name):
if dataset_name not in ['cases', 'labs', 'trks']:
raise ValueError(f'Invalid dataset name: {dataset_name}')
file_path = f'{VITAL_METADATA}/{dataset_name}.csv'
if os.path.isfile(file_path):
print(f'{dataset_name}.csv exists locally.')
df = pd.read_csv(file_path)
return df
else:
print(f'downloading {dataset_name} and storing in the local cache for future reuse.')
df = pd.read_csv(f'https://api.vitaldb.net/{dataset_name}')
df.to_csv(file_path, index=False)
return df
cases = vitaldb_dataframe_loader('cases')
cases = cases.set_index('caseid')
cases.shape
cases.csv exists locally.
(6388, 73)
cases.index.nunique()
6388
cases.head()
| subjectid | casestart | caseend | anestart | aneend | opstart | opend | adm | dis | icu_days | ... | intraop_colloid | intraop_ppf | intraop_mdz | intraop_ftn | intraop_rocu | intraop_vecu | intraop_eph | intraop_phe | intraop_epi | intraop_ca | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| caseid | |||||||||||||||||||||
| 1 | 5955 | 0 | 11542 | -552 | 10848.0 | 1668 | 10368 | -236220 | 627780 | 0 | ... | 0 | 120 | 0.0 | 100 | 70 | 0 | 10 | 0 | 0 | 0 |
| 2 | 2487 | 0 | 15741 | -1039 | 14921.0 | 1721 | 14621 | -221160 | 1506840 | 0 | ... | 0 | 150 | 0.0 | 0 | 100 | 0 | 20 | 0 | 0 | 0 |
| 3 | 2861 | 0 | 4394 | -590 | 4210.0 | 1090 | 3010 | -218640 | 40560 | 0 | ... | 0 | 0 | 0.0 | 0 | 50 | 0 | 0 | 0 | 0 | 0 |
| 4 | 1903 | 0 | 20990 | -778 | 20222.0 | 2522 | 17822 | -201120 | 576480 | 1 | ... | 0 | 80 | 0.0 | 100 | 100 | 0 | 50 | 0 | 0 | 0 |
| 5 | 4416 | 0 | 21531 | -1009 | 22391.0 | 2591 | 20291 | -67560 | 3734040 | 13 | ... | 0 | 0 | 0.0 | 0 | 160 | 0 | 10 | 900 | 0 | 2100 |
5 rows × 73 columns
cases['sex'].value_counts()
sex
M    3243
F    3145
Name: count, dtype: int64
trks = vitaldb_dataframe_loader('trks')
trks = trks.set_index('caseid')
trks.shape
trks.csv exists locally.
(486449, 2)
trks.index.nunique()
6388
trks.groupby('caseid')[['tid']].count().plot();
trks.groupby('caseid')[['tid']].count().hist();
trks.groupby('tname').count().sort_values(by='tid', ascending=False)
| tid | |
|---|---|
| tname | |
| Solar8000/HR | 6387 |
| Solar8000/PLETH_SPO2 | 6386 |
| Solar8000/PLETH_HR | 6386 |
| Primus/CO2 | 6362 |
| Primus/PAMB_MBAR | 6361 |
| ... | ... |
| Orchestra/AMD_VOL | 1 |
| Solar8000/ST_V5 | 1 |
| Orchestra/NPS_VOL | 1 |
| Orchestra/AMD_RATE | 1 |
| Orchestra/VEC_VOL | 1 |
196 rows × 1 columns
SNUADC/ART
arterial blood pressure waveform
Parameter, Description, Type/Hz, Unit
SNUADC/ART, Arterial pressure wave, W/500, mmHg
trks[trks['tname'].str.contains('SNUADC/ART')].shape
(3645, 2)
SNUADC/ECG_II
electrocardiogram waveform
Parameter, Description, Type/Hz, Unit
SNUADC/ECG_II, ECG lead II wave, W/500, mV
trks[trks['tname'].str.contains('SNUADC/ECG_II')].shape
(6355, 2)
BIS/EEG1_WAV
electroencephalogram waveform
Parameter, Description, Type/Hz, Unit
BIS/EEG1_WAV, EEG wave from channel 1, W/128, uV
trks[trks['tname'].str.contains('BIS/EEG1_WAV')].shape
(5871, 2)
These are the subset of case ids for which modelling and analysis will be performed based upon inclusion criteria and waveform data availability.
# TRACK NAMES is used for metadata analysis via API
TRACK_NAMES = ['SNUADC/ART', 'SNUADC/ECG_II', 'BIS/EEG1_WAV']
TRACK_SRATES = [500, 500, 128]
# EXTRACTION TRACK NAMES adds the EVENT track which is only used when doing actual file i/o
EXTRACTION_TRACK_NAMES = ['SNUADC/ART', 'SNUADC/ECG_II', 'BIS/EEG1_WAV', 'EVENT']
EXTRACTION_TRACK_SRATES = [500, 500, 128, 1]
# As in the paper, select cases which meet the following criteria:
#
# For patients, the inclusion criteria were as follows:
# (1) adults (age >= 18)
# (2) administered general anaesthesia
# (3) undergone non-cardiac surgery.
#
# For waveform data, the inclusion criteria were as follows:
# (1) no missing monitoring for ABP, ECG, and EEG waveforms
# (2) no cases containing false events or non-events due to poor signal quality
# (checked in second stage of data preprocessing)
# Adult
inclusion_1 = cases.loc[cases['age'] >= 18].index
print(f'{len(cases)-len(inclusion_1)} cases excluded, {len(inclusion_1)} remaining due to age criteria')
# General Anesthesia
inclusion_2 = cases.loc[cases['ane_type'] == 'General'].index
print(f'{len(cases)-len(inclusion_2)} cases excluded, {len(inclusion_2)} remaining due to anesthesia criteria')
# Non-cardiac surgery
inclusion_3 = cases.loc[
~cases['opname'].str.contains("cardiac", case=False)
& ~cases['opname'].str.contains("aneurysmal", case=False)
].index
print(f'{len(cases)-len(inclusion_3)} cases excluded, {len(inclusion_3)} remaining due to non-cardiac surgery criteria')
# ABP, ECG, EEG waveforms
inclusion_4 = trks.loc[trks['tname'].isin(TRACK_NAMES)].index.value_counts()
inclusion_4 = inclusion_4[inclusion_4 == len(TRACK_NAMES)].index
print(f'{len(cases)-len(inclusion_4)} cases excluded, {len(inclusion_4)} remaining due to missing waveform data')
# SQI filter
# NOTE: this depends on a sqi_filter.csv generated by external processing
inclusion_5 = pd.read_csv('sqi_filter.csv', header=None, names=['caseid','sqi']).set_index('caseid').index
print(f'{len(cases)-len(inclusion_5)} cases excluded, {len(inclusion_5)} remaining due to SQI threshold not being met')
# Only include cases with known good waveforms.
exclusion_6 = pd.read_csv('malformed_tracks_filter.csv', header=None, names=['caseid']).set_index('caseid').index
inclusion_6 = cases.index.difference(exclusion_6)
print(f'{len(cases)-len(inclusion_6)} cases excluded, {len(inclusion_6)} remaining due to malformed waveforms')
cases_of_interest_idx = inclusion_1 \
.intersection(inclusion_2) \
.intersection(inclusion_3) \
.intersection(inclusion_4) \
.intersection(inclusion_5) \
.intersection(inclusion_6)
cases_of_interest = cases.loc[cases_of_interest_idx]
print()
print(f'{cases_of_interest_idx.shape[0]} out of {cases.shape[0]} total cases remaining after exclusions applied')
# Trim cases of interest to MAX_CASES
if MAX_CASES:
cases_of_interest_idx = cases_of_interest_idx[:MAX_CASES]
print(f'{cases_of_interest_idx.shape[0]} cases of interest selected')
57 cases excluded, 6331 remaining due to age criteria
345 cases excluded, 6043 remaining due to anesthesia criteria
14 cases excluded, 6374 remaining due to non-cardiac surgery criteria
3019 cases excluded, 3369 remaining due to missing waveform data
0 cases excluded, 6388 remaining due to SQI threshold not being met
186 cases excluded, 6202 remaining due to malformed waveforms
3110 out of 6388 total cases remaining after exclusions applied
3110 cases of interest selected
cases_of_interest.head(n=5)
| subjectid | casestart | caseend | anestart | aneend | opstart | opend | adm | dis | icu_days | ... | intraop_colloid | intraop_ppf | intraop_mdz | intraop_ftn | intraop_rocu | intraop_vecu | intraop_eph | intraop_phe | intraop_epi | intraop_ca | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| caseid | |||||||||||||||||||||
| 1 | 5955 | 0 | 11542 | -552 | 10848.0 | 1668 | 10368 | -236220 | 627780 | 0 | ... | 0 | 120 | 0.0 | 100 | 70 | 0 | 10 | 0 | 0 | 0 |
| 4 | 1903 | 0 | 20990 | -778 | 20222.0 | 2522 | 17822 | -201120 | 576480 | 1 | ... | 0 | 80 | 0.0 | 100 | 100 | 0 | 50 | 0 | 0 | 0 |
| 7 | 5124 | 0 | 15770 | 477 | 14817.0 | 3177 | 14577 | -154320 | 623280 | 3 | ... | 0 | 0 | 0.0 | 0 | 120 | 0 | 0 | 0 | 0 | 0 |
| 10 | 2175 | 0 | 20992 | -1743 | 21057.0 | 2457 | 19857 | -220740 | 3580860 | 1 | ... | 0 | 90 | 0.0 | 0 | 110 | 0 | 20 | 500 | 0 | 600 |
| 12 | 491 | 0 | 31203 | -220 | 31460.0 | 5360 | 30860 | -208500 | 1519500 | 4 | ... | 200 | 100 | 0.0 | 100 | 70 | 0 | 20 | 0 | 0 | 3300 |
5 rows × 73 columns
These are the subset of tracks (waveforms) for the cases of interest identified above.
# A single case maps to one or more waveform tracks. Select only the tracks required for analysis.
trks_of_interest = trks.loc[cases_of_interest_idx][trks.loc[cases_of_interest_idx]['tname'].isin(TRACK_NAMES)]
trks_of_interest.shape
(9330, 2)
trks_of_interest.head(n=5)
| tname | tid | |
|---|---|---|
| caseid | ||
| 1 | BIS/EEG1_WAV | 0aa685df768489a18a5e9f53af0d83bf60890c73 |
| 1 | SNUADC/ART | 724cdd7184d7886b8f7de091c5b135bd01949959 |
| 1 | SNUADC/ECG_II | 8c9161aaae8cb578e2aa7b60f44234d98d2b3344 |
| 4 | BIS/EEG1_WAV | 1b4c2379be3397a79d3787dd810190150dc53f27 |
| 4 | SNUADC/ART | e28777c4706fe3a5e714bf2d91821d22d782d802 |
trks_of_interest_idx = trks_of_interest.set_index('tid').index
trks_of_interest_idx.shape
(9330,)
Track data are large and therefore expensive to download each time they are used.
By default, the .vital file format stores all tracks for each case internally. Since only select tracks per case are required, each .vital file can be further reduced by discarding the unused tracks.
# Ensure the full vital file dataset is available for cases of interest.
count_downloaded = 0
count_present = 0
#for i, idx in enumerate(cases.index):
for idx in cases_of_interest_idx:
full_path = f'{VITAL_ALL}/{idx:04d}.vital'
if not os.path.isfile(full_path):
print(f'Missing vital file: {full_path}')
# Download and save the file.
vf = vitaldb.VitalFile(idx)
vf.to_vital(full_path)
count_downloaded += 1
else:
count_present += 1
print()
print(f'Count of cases of interest: {cases_of_interest_idx.shape[0]}')
print(f'Count of vital files downloaded: {count_downloaded}')
print(f'Count of vital files already present: {count_present}')
Count of cases of interest: 3110
Count of vital files downloaded: 0
Count of vital files already present: 3110
# Convert vital files to "mini" versions including only the subset of tracks defined in TRACK_NAMES above.
# Only perform conversion for the cases of interest.
# NOTE: If this cell is interrupted, it can be restarted and will continue where it left off.
count_minified = 0
count_present = 0
count_missing_tracks = 0
count_not_fixable = 0
vf = vitaldb.VitalFile('./vitaldb_cache/vital_all/0001.vital', EXTRACTION_TRACK_NAMES)
print(vf)
# If set to true, local mini files are checked for all tracks even if already present.
FORCE_VALIDATE = False
for idx in cases_of_interest_idx:
full_path = f'{VITAL_ALL}/{idx:04d}.vital'
mini_path = f'{VITAL_MINI}/{idx:04d}_mini.vital'
if FORCE_VALIDATE or not os.path.isfile(mini_path):
print(f'Creating mini vital file: {idx}')
vf = vitaldb.VitalFile(full_path, EXTRACTION_TRACK_NAMES)
if len(vf.get_track_names()) != 4:
print(f'Missing track in vital file: {idx}, {set(EXTRACTION_TRACK_NAMES).difference(set(vf.get_track_names()))}')
count_missing_tracks += 1
# Attempt to download from VitalDB directly and see if missing tracks are present.
vf = vitaldb.VitalFile(idx, EXTRACTION_TRACK_NAMES)
if len(vf.get_track_names()) != 3:
print(f'Unable to fix missing tracks: {idx}')
count_not_fixable += 1
continue
if vf.get_track_samples(EXTRACTION_TRACK_NAMES[0], 1/EXTRACTION_TRACK_SRATES[0]).shape[0] == 0:
print(f'Empty track: {idx}, {EXTRACTION_TRACK_NAMES[0]}')
count_not_fixable += 1
continue
if vf.get_track_samples(EXTRACTION_TRACK_NAMES[1], 1/EXTRACTION_TRACK_SRATES[1]).shape[0] == 0:
print(f'Empty track: {idx}, {EXTRACTION_TRACK_NAMES[1]}')
count_not_fixable += 1
continue
if vf.get_track_samples(EXTRACTION_TRACK_NAMES[2], 1/EXTRACTION_TRACK_SRATES[2]).shape[0] == 0:
print(f'Empty track: {idx}, {EXTRACTION_TRACK_NAMES[2]}')
count_not_fixable += 1
continue
# if vf.get_track_samples(EXTRACTION_TRACK_NAMES[3], 1/EXTRACTION_TRACK_SRATES[3]).shape[0] == 0:
# print(f'Empty track: {idx}, {EXTRACTION_TRACK_NAMES[3]}')
# count_not_fixable += 1
# continue
vf.to_vital(mini_path)
count_minified += 1
else:
count_present += 1
print()
print(f'Count of cases of interest: {cases_of_interest_idx.shape[0]}')
print(f'Count of vital files minified: {count_minified}')
print(f'Count of vital files already present: {count_present}')
print(f'Count of vital files missing tracks: {count_missing_tracks}')
print(f'Count of vital files not fixable: {count_not_fixable}')
VitalFile('./vitaldb_cache/vital_all/0001.vital', '['EVENT', 'SNUADC/ART', 'SNUADC/ECG_II', 'BIS/EEG1_WAV']')
Count of cases of interest: 3110
Count of vital files minified: 0
Count of vital files already present: 3110
Count of vital files missing tracks: 0
Count of vital files not fixable: 0
# Convert vital files to "mini" versions including only the subset of tracks defined in TRACK_NAMES above.
# Only perform conversion for the cases of interest.
# NOTE: If this cell is interrupted, it can be restarted and will continue where it left off.
count_missing_tracks = 0
# If true, perform fast validate that all mini files have 3 tracks.
FORCE_VALIDATE = False
if FORCE_VALIDATE:
for idx in cases_of_interest_idx:
mini_path = f'{VITAL_MINI}/{idx:04d}_mini.vital'
if os.path.isfile(mini_path):
vf = vitaldb.VitalFile(mini_path)
if len(vf.get_track_names()) != 3:
print(f'Missing track in vital file: {idx}, {set(TRACK_NAMES).difference(set(vf.get_track_names()))}')
count_missing_tracks += 1
print()
print(f'Count of cases of interest: {cases_of_interest_idx.shape[0]}')
print(f'Count of vital files missing tracks: {count_missing_tracks}')
Count of cases of interest: 3110
Count of vital files missing tracks: 0
Preprocessing characteristics are different for each of the three signal categories:
apply_bandpass_filter() implements the bandpass filter using scipy.signal
apply_zscore_normalization() implements the Z-score normalization using numpy
from scipy.signal import butter, lfilter, spectrogram
# define two methods for data preprocessing
def apply_bandpass_filter(data, lowcut, highcut, fs, order=5):
b, a = butter(order, [lowcut, highcut], fs=fs, btype='band')
y = lfilter(b, a, np.nan_to_num(data))
return y
def apply_zscore_normalization(signal):
mean = np.nanmean(signal)
std = np.nanstd(signal)
return (signal - mean) / std
# Filtering Demonstration
# temp experimental, code to be incorporated into overall preloader process
# for now it's just dumping example plots of the before/after filtered signal data
caseidx = 1
file_path = f"{VITAL_MINI}/{caseidx:04d}_mini.vital"
vf = vitaldb.VitalFile(file_path, TRACK_NAMES)
originalAbp = None
filteredAbp = None
originalEcg = None
filteredEcg = None
originalEeg = None
filteredEeg = None
ABP_TRACK_NAME = "SNUADC/ART"
ECG_TRACK_NAME = "SNUADC/ECG_II"
EEG_TRACK_NAME = "BIS/EEG1_WAV"
for i, (track_name, rate) in enumerate(zip(TRACK_NAMES, TRACK_SRATES)):
# Get samples for this track
track_samples = vf.get_track_samples(track_name, 1/rate)
#track_samples, _ = vf.get_samples(track_name, 1/rate)
print(f"Track {track_name} @ {rate}Hz shape {len(track_samples)}")
if track_name == ABP_TRACK_NAME:
# ABP waveforms are used without further pre-processing
originalAbp = track_samples
filteredAbp = track_samples
elif track_name == ECG_TRACK_NAME:
originalEcg = track_samples
# ECG waveforms are band-pass filtered between 1 and 40 Hz, and Z-score normalized
# first apply bandpass filter
filteredEcg = apply_bandpass_filter(track_samples, 1, 40, rate)
# then do z-score normalization
filteredEcg = apply_zscore_normalization(filteredEcg)
elif track_name == EEG_TRACK_NAME:
# EEG waveforms are band-pass filtered between 0.5 and 50 Hz
originalEeg = track_samples
filteredEeg = apply_bandpass_filter(track_samples, 0.5, 50, rate, 2)
def plotSignal(data, title):
plt.figure(figsize=(20, 5))
plt.plot(data)
plt.title(title)
plt.show()
plotSignal(originalAbp, "Original ABP")
plotSignal(filteredAbp, "Filtered ABP")
plotSignal(originalEcg, "Original ECG")
plotSignal(filteredEcg, "Filtered ECG")
plotSignal(originalEeg, "Original EEG")
plotSignal(filteredEeg, "Filtered EEG")
Track SNUADC/ART @ 500Hz shape 5770575
Track SNUADC/ECG_II @ 500Hz shape 5770575
Track BIS/EEG1_WAV @ 128Hz shape 1477268
# Preprocess data tracks
ABP_TRACK_NAME = "SNUADC/ART"
ECG_TRACK_NAME = "SNUADC/ECG_II"
EEG_TRACK_NAME = "BIS/EEG1_WAV"
EVENT_TRACK_NAME = "EVENT"
MINI_FILE_FOLDER = VITAL_MINI
CACHE_FILE_FOLDER = VITAL_PREPROCESS_SCRATCH
if RESET_CACHE:
TRACK_CACHE = None
SEGMENT_CACHE = None
if TRACK_CACHE is None:
TRACK_CACHE = {}
SEGMENT_CACHE = {}
def get_track_data(case, print_when_file_loaded = False):
parsedFile = None
abp = None
eeg = None
ecg = None
events = None
for i, (track_name, rate) in enumerate(zip(EXTRACTION_TRACK_NAMES, EXTRACTION_TRACK_SRATES)):
# use integer case id and track name, delimited by pipe, as cache key
cache_label = f"{case}|{track_name}"
if cache_label not in TRACK_CACHE:
if parsedFile is None:
file_path = f"{MINI_FILE_FOLDER}/{case:04d}_mini.vital"
if print_when_file_loaded:
print(f"[{datetime.now()}] Loading vital file {file_path}")
parsedFile = vitaldb.VitalFile(file_path, EXTRACTION_TRACK_NAMES)
dataset = np.array(parsedFile.get_track_samples(track_name, 1/rate))
if track_name == ABP_TRACK_NAME:
# no filtering for ABP
abp = dataset
abp = pd.DataFrame(abp).ffill(axis=0).bfill(axis=0)[0].values
if USE_MEMORY_CACHING:
TRACK_CACHE[cache_label] = abp
elif track_name == ECG_TRACK_NAME:
ecg = dataset
# apply ECG filtering: first bandpass then do z-score normalization
ecg = pd.DataFrame(ecg).ffill(axis=0).bfill(axis=0)[0].values
ecg = apply_bandpass_filter(ecg, 1, 40, rate, 2)
ecg = apply_zscore_normalization(ecg)
if USE_MEMORY_CACHING:
TRACK_CACHE[cache_label] = ecg
elif track_name == EEG_TRACK_NAME:
eeg = dataset
eeg = pd.DataFrame(eeg).ffill(axis=0).bfill(axis=0)[0].values
# apply EEG filtering: bandpass only
eeg = apply_bandpass_filter(eeg, 0.5, 50, rate, 2)
if USE_MEMORY_CACHING:
TRACK_CACHE[cache_label] = eeg
elif track_name == EVENT_TRACK_NAME:
events = dataset
if USE_MEMORY_CACHING:
TRACK_CACHE[cache_label] = events
else:
# cache hit, pull from cache
if track_name == ABP_TRACK_NAME:
abp = TRACK_CACHE[cache_label]
elif track_name == ECG_TRACK_NAME:
ecg = TRACK_CACHE[cache_label]
elif track_name == EEG_TRACK_NAME:
eeg = TRACK_CACHE[cache_label]
elif track_name == EVENT_TRACK_NAME:
events = TRACK_CACHE[cache_label]
return (abp, ecg, eeg, events)
# ABP waveforms are used without further pre-processing
# ECG waveforms are band-pass filtered between 1 and 40 Hz, and Z-score normalized
# EEG waveforms are band-pass filtered between 0.5 and 50 Hz
if PRELOADING_CASES:
# determine disk cache file label
maxlabel = "ALL"
if MAX_CASES is not None:
maxlabel = str(MAX_CASES)
picklefile = f"{CACHE_FILE_FOLDER}/{PREDICTION_WINDOW}_minutes_MAX{maxlabel}.trackcache"
for track in tqdm(cases_of_interest_idx):
# getting track data will cause a cache-check and fill when missing
# will also apply appropriate filtering per track
get_track_data(track, False)
print(f"Generated track cache, {len(TRACK_CACHE)} records generated")
def get_segment_data(file_path):
abp = None
eeg = None
ecg = None
if USE_MEMORY_CACHING:
if file_path in SEGMENT_CACHE:
(abp, ecg, eeg) = SEGMENT_CACHE[file_path]
return (abp, ecg, eeg)
try:
with h5py.File(file_path, 'r') as f:
abp = np.array(f['abp'])
ecg = np.array(f['ecg'])
eeg = np.array(f['eeg'])
if len(abp) > 30000:
abp = abp[:30000]
elif len(abp) < 30000:
abp = np.resize(abp, (30000))
if len(ecg) > 30000:
ecg = ecg[:30000]
elif len(ecg) < 30000:
ecg = np.resize(ecg, (30000))
if len(eeg) > 7680:
eeg = eeg[:7680]
elif len(eeg) < 7680:
eeg = np.resize(eeg, (7680))
if USE_MEMORY_CACHING:
SEGMENT_CACHE[file_path] = (abp, ecg, eeg)
except Exception:
# unreadable or malformed segment file; return None for all tracks
abp = None
ecg = None
eeg = None
return (abp, ecg, eeg)
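For reference, the following round-trip sketch shows the HDF5 layout that get_segment_data expects: 60 s of ABP and ECG at 500 Hz (30000 samples) and 60 s of EEG at 128 Hz (7680 samples). The file name here is hypothetical and the data are synthetic.

```python
import numpy as np
import h5py

# write a synthetic segment in the expected layout
with h5py.File('example_segment.h5', 'w') as f:
    f.create_dataset('abp', data=np.full(30000, 80.0))  # ABP @ 500 Hz, 60 s
    f.create_dataset('ecg', data=np.zeros(30000))       # ECG @ 500 Hz, 60 s
    f.create_dataset('eeg', data=np.zeros(7680))        # EEG @ 128 Hz, 60 s

# read it back the same way get_segment_data does
with h5py.File('example_segment.h5', 'r') as f:
    abp, ecg, eeg = np.array(f['abp']), np.array(f['ecg']), np.array(f['eeg'])
print(abp.shape, ecg.shape, eeg.shape)  # (30000,) (30000,) (7680,)
```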
The following method is adapted from the preprocessing block of reference [6] (https://github.com/vitaldb/examples/blob/master/hypotension_art.ipynb)
The approach first finds an intraoperative hypotensive event in the ABP waveform. It then backtracks to an earlier point in the waveform and extracts a 60-second segment representing the waveform features to use as model input. The figure below shows an example of this approach and is reproduced from the VitalDB example notebook referenced above.

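The backtracking arithmetic can be sketched as follows. `predictive_slice` is a hypothetical helper that mirrors the index math used later in `extract_segments`; the example numbers correspond to the first IOH event of case 1 (event onset at 1788 s), whose 3-minute-window input segment covers seconds 1548-1608.

```python
SRATE_HZ = 500  # ABP/ECG sampling rate used throughout this notebook

def predictive_slice(event_start_s, pred_window_min, srate=SRATE_HZ):
    # back up pred_window_min minutes from the event onset, then take
    # the preceding 60-second segment as the model input window
    seg_end_s = event_start_s - pred_window_min * 60
    seg_start_s = seg_end_s - 60
    return seg_start_s * srate, seg_end_s * srate

# case 1's first IOH event begins at t = 1788 s; a 3-minute prediction
# window yields the input segment covering seconds 1548-1608
lo, hi = predictive_slice(1788, 3)
```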
def getSurgeryBoundariesInSeconds(event, debug=False):
# event==event is False for NaN, so this selects only annotated indices
eventIndices = np.argwhere(event == event)
# we are looking for the last index where the string contains 'started'
lastStart = 0
firstFinish = len(event)-1
# find last start
for idx in eventIndices:
if 'started' in event[idx[0]]:
if debug:
print(event[idx[0]])
print(idx[0])
lastStart = idx[0]
# find first finish
for idx in eventIndices:
if 'finish' in event[idx[0]]:
if debug:
print(event[idx[0]])
print(idx[0])
firstFinish = idx[0]
break
if debug:
print(f'lastStart, firstFinish: {lastStart}, {firstFinish}')
return (lastStart, firstFinish)
def areCaseSegmentsCached(caseid):
seg_folder = f"{VITAL_EXTRACTED_SEGMENTS}/{caseid:04d}"
return os.path.exists(seg_folder) and len(os.listdir(seg_folder)) > 0
def isAbpSegmentValidNumpy(samples, debug=False):
valid = True
if np.isnan(samples).mean() > 0.1:
valid = False
if debug:
print(f">10% NaN")
elif (samples > 200).any():
valid = False
if debug:
print(f"Presence of BP > 200")
elif (samples < 30).any():
valid = False
if debug:
print(f"Presence of BP < 30")
elif np.max(samples) - np.min(samples) < 30:
if debug:
print(f"Max - Min test < 30")
valid = False
elif (np.abs(np.diff(samples)) > 30).any(): # abrupt change -> noise
if debug:
print(f"Abrupt change (noise)")
valid = False
return valid
def isAbpSegmentValid(vf, debug=False):
ABP_ECG_SRATE_HZ = 500
ABP_TRACK_NAME = "SNUADC/ART"
samples = np.array(vf.get_track_samples(ABP_TRACK_NAME, 1/ABP_ECG_SRATE_HZ))
return isAbpSegmentValidNumpy(samples, debug)
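A condensed, self-contained copy of the validity rules, shown with two illustrative inputs (a plausible pulsatile trace and a flat-line trace); the inputs are synthetic, not VitalDB data.

```python
import numpy as np

def abp_segment_ok(x):
    # condensed copy of the checks in isAbpSegmentValidNumpy
    if np.isnan(x).mean() > 0.1:           # more than 10% NaN
        return False
    if (x > 200).any() or (x < 30).any():  # physiologically implausible BP
        return False
    if np.max(x) - np.min(x) < 30:         # too little pulse pressure
        return False
    if (np.abs(np.diff(x)) > 30).any():    # abrupt jump -> noise
        return False
    return True

good = 80 + 20 * np.sin(np.linspace(0, 40 * np.pi, 30000))  # pulsatile 60-100 mmHg
flat = np.full(30000, 80.0)                                 # flat-line trace
```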
def saveCaseSegments(caseid, positiveSegments, negativeSegments, compresslevel=9, debug=False, forceWrite=False):
if len(positiveSegments) == 0 and len(negativeSegments) == 0:
# exit early if no events found
print(f'{caseid}: exit early, no segments to save')
return
# event composition
# predictiveSegmentStart in seconds, predictiveSegmentEnd in seconds, predWindow (0 for negative), abp, ecg, eeg)
# 0start, 1end, 2predwindow, 3abp, 4ecg, 5eeg
seg_folder = f"{VITAL_EXTRACTED_SEGMENTS}/{caseid:04d}"
if not os.path.exists(seg_folder):
# if directory needs to be created, then there are no cached segments
os.mkdir(seg_folder)
else:
if not forceWrite:
# exit early if folder already exists, case already produced
return
# prior to writing files out, clear existing files
for filename in os.listdir(seg_folder):
file_path = os.path.join(seg_folder, filename)
if debug:
print(f'deleting: {file_path}')
try:
if os.path.isfile(file_path):
os.unlink(file_path)
except Exception as e:
print('Failed to delete %s. Reason: %s' % (file_path, e))
count_pos_saved = 0
for i in range(0, len(positiveSegments)):
event = positiveSegments[i]
startIndex = event[0]
endIndex = event[1]
predWindow = event[2]
abp = event[3]
#ecg = event[4]
#eeg = event[5]
seg_filename = f"{caseid:04d}_{startIndex}_{predWindow:02d}_True.h5"
seg_fullpath = f"{seg_folder}/{seg_filename}"
if isAbpSegmentValidNumpy(abp, debug):
count_pos_saved += 1
abp = abp.tolist()
ecg = event[4].tolist()
eeg = event[5].tolist()
f = h5py.File(seg_fullpath, "w")
f.create_dataset('abp', data=abp, compression="gzip", compression_opts=compresslevel)
f.create_dataset('ecg', data=ecg, compression="gzip", compression_opts=compresslevel)
f.create_dataset('eeg', data=eeg, compression="gzip", compression_opts=compresslevel)
f.flush()
f.close()
f = None
abp = None
ecg = None
eeg = None
# f.create_dataset('label', data=[1], compression="gzip", compression_opts=compresslevel)
# f.create_dataset('pred_window', data=[event[2]], compression="gzip", compression_opts=compresslevel)
# f.create_dataset('caseid', data=[caseid], compression="gzip", compression_opts=compresslevel)
elif debug:
print(f"{caseid:04d} {predWindow:02d}min {startIndex} starttime = ignored, segment validity issues")
count_neg_saved = 0
for i in range(0, len(negativeSegments)):
event = negativeSegments[i]
startIndex = event[0]
endIndex = event[1]
predWindow = event[2]
abp = event[3]
#ecg = event[4]
#eeg = event[5]
seg_filename = f"{caseid:04d}_{startIndex}_0_False.h5"
seg_fullpath = f"{seg_folder}/{seg_filename}"
if isAbpSegmentValidNumpy(abp, debug):
count_neg_saved += 1
abp = abp.tolist()
ecg = event[4].tolist()
eeg = event[5].tolist()
f = h5py.File(seg_fullpath, "w")
f.create_dataset('abp', data=abp, compression="gzip", compression_opts=compresslevel)
f.create_dataset('ecg', data=ecg, compression="gzip", compression_opts=compresslevel)
f.create_dataset('eeg', data=eeg, compression="gzip", compression_opts=compresslevel)
f.flush()
f.close()
f = None
abp = None
ecg = None
eeg = None
# f.create_dataset('label', data=[0], compression="gzip", compression_opts=compresslevel)
# f.create_dataset('pred_window', data=[0], compression="gzip", compression_opts=compresslevel)
# f.create_dataset('caseid', data=[caseid], compression="gzip", compression_opts=compresslevel)
elif debug:
print(f"{caseid:04d} CleanWindow {startIndex} starttime = ignored, segment validity issues")
if count_neg_saved == 0 and count_pos_saved == 0:
print(f'{caseid}: nothing saved, all segments filtered')
# Generate hypotensive events
# Hypotensive events are defined as a 1-minute interval with sustained ABP of less than 65 mmHg
# Note: Hypotensive events should be at least 20 minutes apart to minimize potential residual effects from previous events
# Generate hypotension non-events
# To sample non-events, 30-minute segments where the ABP was above 75 mmHg were selected, and then
# three one-minute samples of each waveform were obtained from the middle of the segment
# both occur in extract_segments
#VITAL_EXTRACTED_SEGMENTS
def extract_segments(cases_of_interest_idx, debug=False, checkCache=True, forceWrite=False, returnSegments=False):
# Sampling rate for ABP and ECG, Hz. These rates should be the same. Default = 500
ABP_ECG_SRATE_HZ = 500
# Sampling rate for EEG. Default = 128
EEG_SRATE_HZ = 128
# Final dataset for training and testing the model.
positiveSegmentsMap = {}
negativeSegmentsMap = {}
iohEventsMap = {}
cleanEventsMap = {}
# Process each case and extract segments. For each segment identify presence of an event in the label zone.
count_cases = len(cases_of_interest_idx)
#for case_count, caseid in tqdm(enumerate(cases_of_interest_idx), total=count_cases):
for case_count, caseid in enumerate(cases_of_interest_idx):
if debug:
print(f'Loading case: {caseid:04d}, ({case_count + 1} of {count_cases})')
if checkCache and areCaseSegmentsCached(caseid):
if debug:
print(f'Skipping case: {caseid:04d}, already cached')
# skip records we've already cached
continue
# read the arterial waveform
(abp, ecg, eeg, event) = get_track_data(caseid)
if debug:
print(f'Length of {TRACK_NAMES[0]}: {abp.shape[0]}')
print(f'Length of {TRACK_NAMES[1]}: {ecg.shape[0]}')
print(f'Length of {TRACK_NAMES[2]}: {eeg.shape[0]}')
(startInSeconds, endInSeconds) = getSurgeryBoundariesInSeconds(event)
if debug:
print(f"Event markers indicate that surgery begins at {startInSeconds}s and ends at {endInSeconds}s.")
track_length_seconds = int(len(abp) / ABP_ECG_SRATE_HZ)
if debug:
print(f"Processing case {caseid} with length {track_length_seconds}s")
# check if the ABP segment in the surgery window is valid
if debug:
isSurgerySegmentValid = isAbpSegmentValidNumpy(abp[startInSeconds:endInSeconds])
print(f'{caseid}: surgery segment valid: {isSurgerySegmentValid}')
iohEvents = []
cleanEvents = []
i = 0
started = False
eofReached = False
trackStartIndex = None
# set i pointer (which operates in seconds) to start marker for surgery
i = startInSeconds
# FIRST PASS
# in the first forward pass, we are going to identify the start/end boundaries of all IOH events within the case
while i < track_length_seconds - 60 and i < endInSeconds:
segmentStart = None
segmentEnd = None
segFound = False
# look forward one minute
abpSeg = abp[i * ABP_ECG_SRATE_HZ:(i + 60) * ABP_ECG_SRATE_HZ]
# roll forward until we hit a one minute window where mean ABP >= 65 so we know leads are connected and it's tracking
if not started:
if np.nanmean(abpSeg) >= 65:
started = True
trackStartIndex = i
# if we're started and mean abp for the window is <65, we are starting a new IOH event
elif np.nanmean(abpSeg) < 65:
segmentStart = i
# now seek forward to find the end of the event, repeatedly checking the trailing one-minute window
for j in range(i + 60, track_length_seconds):
# look backward one minute
abpSegForward = abp[(j - 60) * ABP_ECG_SRATE_HZ:j * ABP_ECG_SRATE_HZ]
if np.nanmean(abpSegForward) >= 65:
segmentEnd = j - 1
break
if segmentEnd is None:
eofReached = True
else:
# otherwise, end of the IOH segment has been reached, record it
iohEvents.append((segmentStart, segmentEnd))
segFound = True
if debug:
t_abp = abp[segmentStart * ABP_ECG_SRATE_HZ:segmentEnd * ABP_ECG_SRATE_HZ]
isIohSegmentValid = isAbpSegmentValidNumpy(t_abp)
print(f'{caseid}: ioh segment valid: {isIohSegmentValid}, {segmentStart}, {segmentEnd}, {t_abp.shape}')
i += 1
if not started:
continue
elif eofReached:
break
elif segFound:
i = segmentEnd + 1
# SECOND PASS
# in the second forward pass, we are going to identify the start/end boundaries of all non-overlapping 30 minute "clean" windows
# reuse the 'start of signal' index from our first pass
if trackStartIndex is None:
trackStartIndex = startInSeconds
i = trackStartIndex
eofReached = False
while i < track_length_seconds - 1800 and i < endInSeconds:
segmentStart = None
segmentEnd = None
segFound = False
startIndex = i
endIndex = i + 1800
# check to see if this 30 minute window overlaps any IOH events, if so ffwd to end of latest overlapping IOH
overlapFound = False
latestEnd = None
for event in iohEvents:
# case 1: starts during an event
if startIndex >= event[0] and startIndex < event[1]:
latestEnd = event[1]
overlapFound = True
# case 2: ends during an event
elif endIndex >= event[0] and endIndex < event[1]:
latestEnd = event[1]
overlapFound = True
# case 3: event occurs entirely inside of the window
elif startIndex < event[0] and endIndex > event[1]:
latestEnd = event[1]
overlapFound = True
# FFWD if we found an overlap
if overlapFound:
i = latestEnd + 1
continue
# look forward 30 minutes
abpSeg = abp[startIndex * ABP_ECG_SRATE_HZ:endIndex * ABP_ECG_SRATE_HZ]
# if we're started and mean abp for the window is >= 75, we are starting a new clean event
if np.nanmean(abpSeg) >= 75:
overlapFound = False
latestEnd = None
for event in iohEvents:
# case 1: starts during an event
if startIndex >= event[0] and startIndex < event[1]:
latestEnd = event[1]
overlapFound = True
# case 2: ends during an event
elif endIndex >= event[0] and endIndex < event[1]:
latestEnd = event[1]
overlapFound = True
# case 3: event occurs entirely inside of the window
elif startIndex < event[0] and endIndex > event[1]:
latestEnd = event[1]
overlapFound = True
if not overlapFound:
segFound = True
segmentEnd = endIndex
cleanEvents.append((startIndex, endIndex))
if debug:
t_abp = abp[startIndex * ABP_ECG_SRATE_HZ:endIndex * ABP_ECG_SRATE_HZ]
isCleanSegmentValid = isAbpSegmentValidNumpy(t_abp)
print(f'{caseid}: clean segment valid: {isCleanSegmentValid}, {startIndex}, {endIndex}, {t_abp.shape}')
i += 10
if segFound:
i = segmentEnd + 1
if debug:
print(f"IOH Events for case {caseid}: {iohEvents}")
print(f"Clean Events for case {caseid}: {cleanEvents}")
positiveSegments = []
negativeSegments = []
# THIRD PASS
# in the third pass, we will use the collections of ioh event windows to generate our actual extracted segments based on our prediction window (positive labels)
for i in range(0, len(iohEvents)):
if debug:
print(f"Checking event {iohEvents[i]}")
# we want to review current event boundaries, as well as previous event boundaries if available
event = iohEvents[i]
previousEvent = None
if i > 0:
previousEvent = iohEvents[i - 1]
for predWindow in ALL_PREDICTION_WINDOWS:
if debug:
print(f"Checking event {iohEvents[i]} for pred {predWindow}")
iohEventStart = event[0]
predictiveSegmentEnd = event[0] - (predWindow*60)
predictiveSegmentStart = predictiveSegmentEnd - 60
if (predictiveSegmentStart < 0):
# don't rewind before the beginning of the track
if debug:
print(f"Checking event {iohEvents[i]} for pred {predWindow} - exit, before beginning")
continue
elif (predictiveSegmentStart < trackStartIndex):
# don't rewind before the beginning of signal in track
if debug:
print(f"Checking event {iohEvents[i]} for pred {predWindow} - exit, before track start")
continue
elif previousEvent is not None:
# does this event window come before or during the previous event?
overlapFound = False
# case 1: starts during an event
if predictiveSegmentStart >= previousEvent[0] and predictiveSegmentStart < previousEvent[1]:
overlapFound = True
# case 2: ends during an event
elif iohEventStart >= previousEvent[0] and iohEventStart < previousEvent[1]:
overlapFound = True
# case 3: event occurs entirely inside of the window
elif predictiveSegmentStart < previousEvent[0] and iohEventStart > previousEvent[1]:
overlapFound = True
# do not extract a segment if its window overlaps with another IOH event
if overlapFound:
if debug:
print(f"Checking event {iohEvents[i]} for pred {predWindow} - exit, overlap with earlier segment")
continue
# track the positive segment
positiveSegments.append((predictiveSegmentStart, predictiveSegmentEnd, predWindow,
abp[predictiveSegmentStart*ABP_ECG_SRATE_HZ:predictiveSegmentEnd*ABP_ECG_SRATE_HZ],
ecg[predictiveSegmentStart*ABP_ECG_SRATE_HZ:predictiveSegmentEnd*ABP_ECG_SRATE_HZ],
eeg[predictiveSegmentStart*EEG_SRATE_HZ:predictiveSegmentEnd*EEG_SRATE_HZ]))
# FOURTH PASS
# in the fourth and final pass, we will use the collections of clean event windows to generate our actual extracted segments based (negative labels)
for i in range(0, len(cleanEvents)):
# everything will be 30 minutes long at least
event = cleanEvents[i]
# choose sample 1 @ 10 minutes
# choose sample 2 @ 15 minutes
# choose sample 3 @ 20 minutes
timeAtTen = event[0] + 600
timeAtFifteen = event[0] + 900
timeAtTwenty = event[0] + 1200
negativeSegments.append((timeAtTen, timeAtTen + 60, 0,
abp[timeAtTen*ABP_ECG_SRATE_HZ:(timeAtTen + 60)*ABP_ECG_SRATE_HZ],
ecg[timeAtTen*ABP_ECG_SRATE_HZ:(timeAtTen + 60)*ABP_ECG_SRATE_HZ],
eeg[timeAtTen*EEG_SRATE_HZ:(timeAtTen + 60)*EEG_SRATE_HZ]))
negativeSegments.append((timeAtFifteen, timeAtFifteen + 60, 0,
abp[timeAtFifteen*ABP_ECG_SRATE_HZ:(timeAtFifteen + 60)*ABP_ECG_SRATE_HZ],
ecg[timeAtFifteen*ABP_ECG_SRATE_HZ:(timeAtFifteen + 60)*ABP_ECG_SRATE_HZ],
eeg[timeAtFifteen*EEG_SRATE_HZ:(timeAtFifteen + 60)*EEG_SRATE_HZ]))
negativeSegments.append((timeAtTwenty, timeAtTwenty + 60, 0,
abp[timeAtTwenty*ABP_ECG_SRATE_HZ:(timeAtTwenty + 60)*ABP_ECG_SRATE_HZ],
ecg[timeAtTwenty*ABP_ECG_SRATE_HZ:(timeAtTwenty + 60)*ABP_ECG_SRATE_HZ],
eeg[timeAtTwenty*EEG_SRATE_HZ:(timeAtTwenty + 60)*EEG_SRATE_HZ]))
if returnSegments:
positiveSegmentsMap[caseid] = positiveSegments
negativeSegmentsMap[caseid] = negativeSegments
iohEventsMap[caseid] = iohEvents
cleanEventsMap[caseid] = cleanEvents
saveCaseSegments(caseid, positiveSegments, negativeSegments, 9, debug=debug, forceWrite=forceWrite)
#if debug:
print(f'{caseid}: positiveSegments: {len(positiveSegments)}, negativeSegments: {len(negativeSegments)}')
return positiveSegmentsMap, negativeSegmentsMap, iohEventsMap, cleanEventsMap
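The core first-pass test above, a one-minute window whose mean ABP falls below 65 mmHg, can be sketched in isolation. `minute_is_hypotensive` is a hypothetical helper mirroring the window check in `extract_segments`, applied here to a synthetic trace.

```python
import numpy as np

SRATE_HZ = 500  # ABP sampling rate (Hz)

def minute_is_hypotensive(abp, start_s, srate=SRATE_HZ):
    # one-minute look-ahead window with mean ABP below 65 mmHg,
    # matching the first-pass scan in extract_segments
    window = abp[start_s * srate:(start_s + 60) * srate]
    return bool(np.nanmean(window) < 65)

abp = np.full(120 * SRATE_HZ, 80.0)  # two minutes at 80 mmHg
abp[60 * SRATE_HZ:] = 60.0           # second minute drops to 60 mmHg
```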
Ensure that all needed segments are in place for the cases being used. If the data is already stored on disk, this step returns immediately.
MANUAL_EXTRACT=True
if MANUAL_EXTRACT:
mycoi = cases_of_interest_idx
#mycoi = cases_of_interest_idx[:2800]
#mycoi = [1]
cnt = 0
mod = 0
for ci in mycoi:
cnt += 1
if mod % 100 == 0:
print(f'count processed: {mod}, current case index: {ci}')
try:
p, n, i, c = extract_segments([ci], debug=False, checkCache=True, forceWrite=True, returnSegments=False)
p = None
n = None
i = None
c = None
except Exception as e:
print(f'error on extract segment: {ci}: {e}')
mod += 1
print(f'extracted: {cnt}')
count processed: 0, current case index: 1
count processed: 100, current case index: 198
count processed: 200, current case index: 431
count processed: 300, current case index: 665
724: exit early, no segments to save
724: positiveSegments: 0, negativeSegments: 0
818: exit early, no segments to save
818: positiveSegments: 0, negativeSegments: 0
count processed: 400, current case index: 853
count processed: 500, current case index: 1046
count processed: 600, current case index: 1236
1271: exit early, no segments to save
1271: positiveSegments: 0, negativeSegments: 0
count processed: 700, current case index: 1440
1505: exit early, no segments to save
1505: positiveSegments: 0, negativeSegments: 0
count processed: 800, current case index: 1639
count processed: 900, current case index: 1843
count processed: 1000, current case index: 2049
2218: exit early, no segments to save
2218: positiveSegments: 0, negativeSegments: 0
count processed: 1100, current case index: 2281
count processed: 1200, current case index: 2469
count processed: 1300, current case index: 2665
count processed: 1400, current case index: 2888
count processed: 1500, current case index: 3092
count processed: 1600, current case index: 3279
3413: exit early, no segments to save
3413: positiveSegments: 0, negativeSegments: 0
count processed: 1700, current case index: 3475
3476: exit early, no segments to save
3476: positiveSegments: 0, negativeSegments: 0
3533: exit early, no segments to save
3533: positiveSegments: 0, negativeSegments: 0
count processed: 1800, current case index: 3694
count processed: 1900, current case index: 3887
3992: exit early, no segments to save
3992: positiveSegments: 0, negativeSegments: 0
count processed: 2000, current case index: 4091
4187: nothing saved, all segments filtered
4187: positiveSegments: 0, negativeSegments: 18
count processed: 2100, current case index: 4296
4328: exit early, no segments to save
4328: positiveSegments: 0, negativeSegments: 0
count processed: 2200, current case index: 4509
4648: exit early, no segments to save
4648: positiveSegments: 0, negativeSegments: 0
4703: exit early, no segments to save
4703: positiveSegments: 0, negativeSegments: 0
count processed: 2300, current case index: 4732
4733: exit early, no segments to save
4733: positiveSegments: 0, negativeSegments: 0
4834: nothing saved, all segments filtered
4834: positiveSegments: 3, negativeSegments: 0
4836: nothing saved, all segments filtered
4836: positiveSegments: 11, negativeSegments: 6
count processed: 2400, current case index: 4929
4985: nothing saved, all segments filtered
4985: positiveSegments: 1, negativeSegments: 0
5130: exit early, no segments to save
5130: positiveSegments: 0, negativeSegments: 0
count processed: 2500, current case index: 5142
5175: nothing saved, all segments filtered
5175: positiveSegments: 2, negativeSegments: 0
5327: nothing saved, all segments filtered
5327: positiveSegments: 4, negativeSegments: 12
count processed: 2600, current case index: 5346
5501: exit early, no segments to save
5501: positiveSegments: 0, negativeSegments: 0
count processed: 2700, current case index: 5564
5587: nothing saved, all segments filtered
5587: positiveSegments: 2, negativeSegments: 0
5693: exit early, no segments to save
5693: positiveSegments: 0, negativeSegments: 0
count processed: 2800, current case index: 5771
5908: exit early, no segments to save
5908: positiveSegments: 0, negativeSegments: 0
count processed: 2900, current case index: 5974
6131: nothing saved, all segments filtered
6131: positiveSegments: 2, negativeSegments: 0
count processed: 3000, current case index: 6174
count processed: 3100, current case index: 6372
extracted: 3110
def printAbp(case_id_to_check, plot_invalid_only=False):
vf_path = f'{VITAL_MINI}/{case_id_to_check:04d}_mini.vital'
vf = vitaldb.VitalFile(vf_path)
abp = vf.to_numpy(TRACK_NAMES[0], 1/500)
print(f'Case {case_id_to_check}')
print(f'ABP Shape: {abp.shape}')
print(f'nanmin: {np.nanmin(abp)}')
print(f'nanmean: {np.nanmean(abp)}')
print(f'nanmax: {np.nanmax(abp)}')
is_valid = isAbpSegmentValidNumpy(abp, debug=True)
print(f'valid: {is_valid}')
if plot_invalid_only and is_valid:
return
plt.figure(figsize=(20, 5))
plt_color = 'C0' if is_valid else 'red'
plt.plot(abp, plt_color)
plt.title(f'ABP - Entire Track - Case {case_id_to_check} - {abp.shape[0] / 500} seconds')
plt.axhline(y = 65, color = 'maroon', linestyle = '--')
plt.show()
def printSegments(segmentsMap, case_id_to_check, print_label, normalize=False):
for (x1, x2, r, abp, ecg, eeg) in segmentsMap[case_id_to_check]:
print(f'{print_label}: Case {case_id_to_check}')
print(f'lookback window: {r} min')
print(f'start time: {x1}')
print(f'end time: {x2}')
print(f'length: {x2 - x1} sec')
print(f'ABP Shape: {abp.shape}')
print(f'ECG Shape: {ecg.shape}')
print(f'EEG Shape: {eeg.shape}')
print(f'nanmin: {np.nanmin(abp)}')
print(f'nanmean: {np.nanmean(abp)}')
print(f'nanmax: {np.nanmax(abp)}')
is_valid = isAbpSegmentValidNumpy(abp, debug=True)
print(f'valid: {is_valid}')
# ABP normalization
x_abp = np.copy(abp)
if normalize:
x_abp -= 65
x_abp /= 65
plt.figure(figsize=(20, 5))
plt_color = 'C0' if is_valid else 'red'
plt.plot(x_abp, plt_color)
plt.title('ABP')
plt.axhline(y = 65, color = 'maroon', linestyle = '--')
plt.show()
plt.figure(figsize=(20, 5))
plt.plot(ecg, 'teal')
plt.title('ECG')
plt.show()
plt.figure(figsize=(20, 5))
plt.plot(eeg, 'indigo')
plt.title('EEG')
plt.show()
print()
def printEvents(abp_raw, eventsMap, case_id_to_check, print_label, normalize=False):
for (x1, x2) in eventsMap[case_id_to_check]:
print(f'{print_label}: Case {case_id_to_check}')
print(f'start time: {x1}')
print(f'end time: {x2}')
print(f'length: {x2 - x1} sec')
abp = abp_raw[x1*500:x2*500]
print(f'ABP Shape: {abp.shape}')
print(f'nanmin: {np.nanmin(abp)}')
print(f'nanmean: {np.nanmean(abp)}')
print(f'nanmax: {np.nanmax(abp)}')
is_valid = isAbpSegmentValidNumpy(abp, debug=True)
print(f'valid: {is_valid}')
# ABP normalization
x_abp = np.copy(abp)
if normalize:
x_abp -= 65
x_abp /= 65
plt.figure(figsize=(20, 5))
plt_color = 'C0' if is_valid else 'red'
plt.plot(x_abp, plt_color)
plt.title('ABP')
plt.axhline(y = 65, color = 'maroon', linestyle = '--')
plt.show()
print()
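The optional ABP normalization used above maps the 65 mmHg hypotension threshold to 0 and scales by 65 mmHg per unit.

```python
import numpy as np

# 0 marks the 65 mmHg threshold; +1 corresponds to 130 mmHg,
# -0.5 to 32.5 mmHg
abp = np.array([65.0, 130.0, 32.5])
x_abp = (abp - 65) / 65
```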
# Check if all ABPs are well formed.
DISPLAY_REALITY_CHECK_ABP=True
DISPLAY_REALITY_CHECK_ABP_FIRST_ONLY=True
if DISPLAY_REALITY_CHECK_ABP:
for case_id_to_check in cases_of_interest_idx:
printAbp(case_id_to_check, plot_invalid_only=False)
if DISPLAY_REALITY_CHECK_ABP_FIRST_ONLY:
break
Case 1 ABP Shape: (5770575, 1) nanmin: -495.6260070800781 nanmean: 78.15251159667969 nanmax: 374.3236389160156 Presence of BP > 200 valid: False
# These are Vital Files removed because of malformed ABP waveforms.
DISPLAY_MALFORMED_ABP=True
DISPLAY_MALFORMED_ABP_FIRST_ONLY=True
if DISPLAY_MALFORMED_ABP:
malformed_case_ids = pd.read_csv('malformed_tracks_filter.csv', header=None, names=['caseid']).set_index('caseid').index
for case_id_to_check in malformed_case_ids:
printAbp(case_id_to_check)
if DISPLAY_MALFORMED_ABP_FIRST_ONLY:
break
Case 3 ABP Shape: (2196524, 1) nanmin: -117.43000030517578 nanmean: 0.6060270667076111 nanmax: 85.98619842529297 Presence of BP < 30 valid: False
DISPLAY_NO_SEGMENTS_CASES=True
DISPLAY_NO_SEGMENTS_CASES_FIRST_ONLY=True
if DISPLAY_NO_SEGMENTS_CASES:
no_segments_case_ids = [3413, 3476, 3533, 3992, 4328, 4648, 4703, 4733, 5130, 5501, 5693, 5908]
for case_id_to_check in no_segments_case_ids:
printAbp(case_id_to_check)
if DISPLAY_NO_SEGMENTS_CASES_FIRST_ONLY:
break
Case 3413 ABP Shape: (3429927, 1) nanmin: -228.025146484375 nanmean: 48.4425163269043 nanmax: 293.3521423339844 >10% NaN valid: False
Generate segment data for one or more cases.
#mycoi = cases_of_interest_idx
mycoi = cases_of_interest_idx[:1]
#mycoi = [1]
positiveSegmentsMap, negativeSegmentsMap, iohEventsMap, cleanEventsMap = \
extract_segments(mycoi, debug=False, checkCache=False, forceWrite=False, returnSegments=True)
1: positiveSegments: 12, negativeSegments: 9
Select a specific case to check.
case_id_to_check = cases_of_interest_idx[0]
#case_id_to_check = 1
print(case_id_to_check)
1
print((
len(positiveSegmentsMap[case_id_to_check]),
len(negativeSegmentsMap[case_id_to_check]),
len(iohEventsMap[case_id_to_check]),
len(cleanEventsMap[case_id_to_check])
))
(12, 9, 7, 3)
printAbp(case_id_to_check)
Case 1 ABP Shape: (5770575, 1) nanmin: -495.6260070800781 nanmean: 78.15251159667969 nanmax: 374.3236389160156 Presence of BP > 200 valid: False
printSegments(positiveSegmentsMap, case_id_to_check, 'Positive Segment - IOH Event', normalize=False)
Positive Segment - IOH Event: Case 1 lookback window: 3 min start time: 1548 end time: 1608 length: 60 sec ABP Shape: (30000,) ECG Shape: (30000,) EEG Shape: (7680,) nanmin: 46.487884521484375 nanmean: 73.00869750976562 nanmax: 113.63497924804688 valid: True
Positive Segment - IOH Event: Case 1 lookback window: 5 min start time: 1428 end time: 1488 length: 60 sec ABP Shape: (30000,) ECG Shape: (30000,) EEG Shape: (7680,) nanmin: 41.550628662109375 nanmean: 74.47395324707031 nanmax: 128.44686889648438 valid: True
Positive Segment - IOH Event: Case 1 lookback window: 10 min start time: 1128 end time: 1188 length: 60 sec ABP Shape: (30000,) ECG Shape: (30000,) EEG Shape: (7680,) nanmin: 53.400115966796875 nanmean: 88.63211059570312 nanmax: 135.35903930664062 valid: True
Positive Segment - IOH Event: Case 1 lookback window: 15 min start time: 828 end time: 888 length: 60 sec ABP Shape: (30000,) ECG Shape: (30000,) EEG Shape: (7680,) nanmin: 23.776397705078125 nanmean: 108.88127136230469 nanmax: 182.75698852539062 Presence of BP < 30 valid: False
Positive Segment - IOH Event: Case 1 lookback window: 3 min start time: 3873 end time: 3933 length: 60 sec ABP Shape: (30000,) ECG Shape: (30000,) EEG Shape: (7680,) nanmin: 46.487884521484375 nanmean: 75.3544692993164 nanmax: 124.49703979492188 valid: True
Positive Segment - IOH Event: Case 1 lookback window: 5 min start time: 3753 end time: 3813 length: 60 sec ABP Shape: (30000,) ECG Shape: (30000,) EEG Shape: (7680,) nanmin: 45.500457763671875 nanmean: 73.97709655761719 nanmax: 122.52212524414062 valid: True
Positive Segment - IOH Event: Case 1 lookback window: 10 min start time: 3453 end time: 3513 length: 60 sec ABP Shape: (30000,) ECG Shape: (30000,) EEG Shape: (7680,) nanmin: 52.412628173828125 nanmean: 86.52787780761719 nanmax: 148.19595336914062 valid: True
Positive Segment - IOH Event: Case 1 lookback window: 15 min start time: 3153 end time: 3213 length: 60 sec ABP Shape: (30000,) ECG Shape: (30000,) EEG Shape: (7680,) nanmin: 58.337371826171875 nanmean: 100.94121551513672 nanmax: 165.97018432617188 valid: True
Positive Segment - IOH Event: Case 1 lookback window: 3 min start time: 8856 end time: 8916 length: 60 sec ABP Shape: (30000,) ECG Shape: (30000,) EEG Shape: (7680,) nanmin: 64.26211547851562 nanmean: 97.06536102294922 nanmax: 157.08309936523438 valid: True
Positive Segment - IOH Event: Case 1 lookback window: 5 min start time: 8736 end time: 8796 length: 60 sec ABP Shape: (30000,) ECG Shape: (30000,) EEG Shape: (7680,) nanmin: 69.19943237304688 nanmean: 105.55238342285156 nanmax: 163.00784301757812 valid: True
Positive Segment - IOH Event: Case 1 lookback window: 10 min start time: 8436 end time: 8496 length: 60 sec ABP Shape: (30000,) ECG Shape: (30000,) EEG Shape: (7680,) nanmin: -88.793701171875 nanmean: 130.62982177734375 nanmax: 305.2016296386719 Presence of BP > 200 valid: False
Positive Segment - IOH Event: Case 1 lookback window: 15 min start time: 8136 end time: 8196 length: 60 sec ABP Shape: (30000,) ECG Shape: (30000,) EEG Shape: (7680,) nanmin: 62.287200927734375 nanmean: 92.04357147216797 nanmax: 138.32138061523438 valid: True
printSegments(negativeSegmentsMap, case_id_to_check, 'Negative Segment - Non-Event', normalize=False)
Negative Segment - Non-Event: Case 1 lookback window: 0 min start time: 5951 end time: 6011 length: 60 sec ABP Shape: (30000,) ECG Shape: (30000,) EEG Shape: (7680,) nanmin: 52.412628173828125 nanmean: 76.35643005371094 nanmax: 120.54721069335938 valid: True
Negative Segment - Non-Event: Case 1 lookback window: 0 min start time: 6251 end time: 6311 length: 60 sec ABP Shape: (30000,) ECG Shape: (30000,) EEG Shape: (7680,) nanmin: 54.387542724609375 nanmean: 77.73150634765625 nanmax: 120.54721069335938 valid: True
Negative Segment - Non-Event: Case 1 lookback window: 0 min start time: 6551 end time: 6611 length: 60 sec ABP Shape: (30000,) ECG Shape: (30000,) EEG Shape: (7680,) nanmin: 58.337371826171875 nanmean: 85.06976318359375 nanmax: 133.38412475585938 valid: True
Negative Segment - Non-Event: Case 1 lookback window: 0 min start time: 7752 end time: 7812 length: 60 sec ABP Shape: (30000,) ECG Shape: (30000,) EEG Shape: (7680,) nanmin: 55.375030517578125 nanmean: 80.11844635009766 nanmax: 130.42178344726562 valid: True
Negative Segment - Non-Event: Case 1 lookback window: 0 min start time: 8052 end time: 8112 length: 60 sec ABP Shape: (30000,) ECG Shape: (30000,) EEG Shape: (7680,) nanmin: 60.312286376953125 nanmean: 88.32589721679688 nanmax: 134.37161254882812 valid: True
Negative Segment - Non-Event: Case 1 lookback window: 0 min start time: 8352 end time: 8412 length: 60 sec ABP Shape: (30000,) ECG Shape: (30000,) EEG Shape: (7680,) nanmin: 68.21194458007812 nanmean: 182.59963989257812 nanmax: 368.3988952636719 Presence of BP > 200 valid: False
Negative Segment - Non-Event: Case 1 lookback window: 0 min start time: 10104 end time: 10164 length: 60 sec ABP Shape: (30000,) ECG Shape: (30000,) EEG Shape: (7680,) nanmin: 48.462799072265625 nanmean: 72.81173706054688 nanmax: 115.60989379882812 valid: True
Negative Segment - Non-Event: Case 1 lookback window: 0 min start time: 10404 end time: 10464 length: 60 sec ABP Shape: (30000,) ECG Shape: (30000,) EEG Shape: (7680,) nanmin: -7.822235107421875 nanmean: 106.73753356933594 nanmax: 236.07968139648438 Presence of BP > 200 valid: False
Negative Segment - Non-Event: Case 1 lookback window: 0 min start time: 10704 end time: 10764 length: 60 sec ABP Shape: (30000,) ECG Shape: (30000,) EEG Shape: (7680,) nanmin: 110.67263793945312 nanmean: 172.22396850585938 nanmax: 239.04202270507812 Presence of BP > 200 valid: False
tmp_vf_path = f'{VITAL_MINI}/{case_id_to_check:04d}_mini.vital'
tmp_vf = vitaldb.VitalFile(tmp_vf_path)
tmp_abp = tmp_vf.to_numpy(TRACK_NAMES[0], 1/500)
printEvents(tmp_abp, iohEventsMap, case_id_to_check, 'IOH Event Segment', normalize=False)
IOH Event Segment: Case 1 start time: 1788 end time: 1849 length: 61 sec ABP Shape: (30500, 1) nanmin: 32.663482666015625 nanmean: 64.93988037109375 nanmax: 123.50955200195312 valid: True
IOH Event Segment: Case 1 start time: 1850 end time: 2113 length: 263 sec ABP Shape: (131500, 1) nanmin: 37.600799560546875 nanmean: 63.139060974121094 nanmax: 101.78549194335938 valid: True
IOH Event Segment: Case 1 start time: 2314 end time: 2375 length: 61 sec ABP Shape: (30500, 1) nanmin: -262.5861511230469 nanmean: 65.14369201660156 nanmax: 343.7124938964844 Presence of BP > 200 valid: False
IOH Event Segment: Case 1 start time: 4113 end time: 4199 length: 86 sec ABP Shape: (43000, 1) nanmin: 22.788909912109375 nanmean: 65.0725326538086 nanmax: 153.13327026367188 Presence of BP < 30 valid: False
IOH Event Segment: Case 1 start time: 4261 end time: 5350 length: 1089 sec ABP Shape: (544500, 1) nanmin: 36.613311767578125 nanmean: 60.451026916503906 nanmax: 110.67263793945312 valid: True
IOH Event Segment: Case 1 start time: 9096 end time: 9156 length: 60 sec ABP Shape: (30000, 1) nanmin: 40.563140869140625 nanmean: 64.9837646484375 nanmax: 108.69772338867188 valid: True
IOH Event Segment: Case 1 start time: 9157 end time: 9503 length: 346 sec ABP Shape: (173000, 1) nanmin: 39.575714111328125 nanmean: 62.33021545410156 nanmax: 104.74789428710938 valid: True
printEvents(tmp_abp, cleanEventsMap, case_id_to_check, 'Clean Event Segment', normalize=False)
Clean Event Segment: Case 1 start time: 5351 end time: 7151 length: 1800 sec ABP Shape: (900000, 1) nanmin: 40.563140869140625 nanmean: 84.04818725585938 nanmax: 151.15835571289062 valid: True
Clean Event Segment: Case 1 start time: 7152 end time: 8952 length: 1800 sec ABP Shape: (900000, 1) nanmin: -495.6260070800781 nanmean: 99.71124267578125 nanmax: 368.3988952636719 Presence of BP > 200 valid: False
Clean Event Segment: Case 1 start time: 9504 end time: 11304 length: 1800 sec ABP Shape: (900000, 1) nanmin: -49.295440673828125 nanmean: 83.3201675415039 nanmax: 346.6748352050781 Presence of BP > 200 valid: False
# free memory
tmp_abp = None
def get_segment_attributes_from_filename(file_path):
pieces = os.path.basename(file_path).split('_')
case = int(pieces[0])
startX = int(pieces[1])
predWindow = int(pieces[2])
label = pieces[3].replace('.h5', '')
return (case, startX, predWindow, label)
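Segment filenames encode the case ID, segment start time, prediction window, and label as `{case}_{start}_{predwindow}_{label}.h5`. The parser is restated below so the example is self-contained; the example path is hypothetical:

```python
import os

def get_segment_attributes_from_filename(file_path):
    # parser restated from above: {case}_{start}_{predwindow}_{label}.h5
    pieces = os.path.basename(file_path).split('_')
    return (int(pieces[0]), int(pieces[1]), int(pieces[2]),
            pieces[3].replace('.h5', ''))

# hypothetical segment file: case 1, start at t=1788 s, 3-minute prediction window
print(get_segment_attributes_from_filename('/segments/0001_1788_3_True.h5'))
# → (1, 1788, 3, 'True')
```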
count_negative_samples = 0
count_positive_samples = 0
samples = []
from glob import glob
seg_folder = f"{VITAL_EXTRACTED_SEGMENTS}"
filenames = [y for x in os.walk(seg_folder) for y in glob(os.path.join(x[0], '*.h5'))]
for filename in filenames:
(case, start_x, pred_window, label) = get_segment_attributes_from_filename(filename)
#print((case, start_x, pred_window, label))
# only load segments for cases of interest; this folder could hold segments for hundreds of cases
if case not in cases_of_interest_idx:
continue
#PREDICTION_WINDOW = 3
if pred_window == 0 or pred_window == PREDICTION_WINDOW or PREDICTION_WINDOW == 'ALL':
#print((case, start_x, pred_window, label))
if label == 'True':
count_positive_samples += 1
else:
count_negative_samples += 1
sample = (filename, label)
samples.append(sample)
print()
print(f"samples loaded: {len(samples):5} ")
print(f'count negative samples: {count_negative_samples:5}')
print(f'count positive samples: {count_positive_samples:5}')
samples loaded: 62262
count negative samples: 37572
count positive samples: 24690
# Divide by cases
sample_cases = defaultdict(lambda: [])
for fn, _ in samples:
(case, start_x, pred_window, label) = get_segment_attributes_from_filename(fn)
sample_cases[case].append((fn, label))
# understand any missing cases of interest
sample_cases_idx = pd.Index(sample_cases.keys())
missing_case_ids = cases_of_interest_idx.difference(sample_cases_idx)
print(f'cases with no samples: {missing_case_ids.shape[0]}')
print(f' {missing_case_ids}')
print()
# Split data into training, validation, and test sets
# Use 6:1:3 ratio and prevent samples from a single case from being split across different sets
# Note: number of samples at each time point is not the same, because the first event can occur before the 3/5/10/15 minute mark
# Set target sizes
train_ratio = 0.6
val_ratio = 0.1
test_ratio = 1 - train_ratio - val_ratio # ensure ratios sum to 1
# Split samples into train and other
sample_cases_train, sample_cases_other = train_test_split(list(sample_cases.keys()), test_size=(1 - train_ratio), random_state=RANDOM_SEED)
# Split other into val and test
sample_cases_val, sample_cases_test = train_test_split(sample_cases_other, test_size=(test_ratio / (1 - train_ratio)), random_state=RANDOM_SEED)
# Check how many samples are in each set
print(f'Train/Val/Test Summary by Cases')
print(f"Train cases: {len(sample_cases_train):5}, ({len(sample_cases_train) / len(sample_cases):.2%})")
print(f"Val cases: {len(sample_cases_val):5}, ({len(sample_cases_val) / len(sample_cases):.2%})")
print(f"Test cases: {len(sample_cases_test):5}, ({len(sample_cases_test) / len(sample_cases):.2%})")
print(f"Total cases: {(len(sample_cases_train) + len(sample_cases_val) + len(sample_cases_test)):5}")
cases with no samples: 25
Index([ 724, 818, 1271, 1505, 2218, 3413, 3476, 3533, 3992, 4187, 4328, 4648,
4703, 4733, 4834, 4836, 4985, 5130, 5175, 5327, 5501, 5587, 5693, 5908,
6131],
dtype='int64')
Train/Val/Test Summary by Cases
Train cases: 1851, (60.00%)
Val cases: 308, (9.98%)
Test cases: 926, (30.02%)
Total cases: 3085
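The nested `test_size` arithmetic above produces the intended 6:1:3 ratio: 40% of cases are first held out, and test_ratio / (1 − train_ratio) = 0.3 / 0.4 = 75% of that holdout goes to test, leaving 10% overall for validation. A quick sanity check:

```python
train_ratio = 0.6
val_ratio = 0.1
test_ratio = 1 - train_ratio - val_ratio            # 0.3

holdout = 1 - train_ratio                           # 40% of cases go to val + test
test_share_of_holdout = test_ratio / holdout        # 75% of the holdout

# overall fractions implied by the two-stage split
overall_test = holdout * test_share_of_holdout
overall_val = holdout * (1 - test_share_of_holdout)
print(round(overall_val, 3), round(overall_test, 3))  # → 0.1 0.3
```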
sample_cases_train = set(sample_cases_train)
sample_cases_val = set(sample_cases_val)
sample_cases_test = set(sample_cases_test)
samples_train = []
samples_val = []
samples_test = []
for cid, segs in sample_cases.items():
if cid in sample_cases_train:
for seg in segs:
samples_train.append(seg)
if cid in sample_cases_val:
for seg in segs:
samples_val.append(seg)
if cid in sample_cases_test:
for seg in segs:
samples_test.append(seg)
# Check how many samples are in each set
print(f'Train/Val/Test Summary by Events')
print(f"Train events: {len(samples_train):5}, ({len(samples_train) / len(samples):.2%})")
print(f"Val events: {len(samples_val):5}, ({len(samples_val) / len(samples):.2%})")
print(f"Test events: {len(samples_test):5}, ({len(samples_test) / len(samples):.2%})")
print(f"Total events: {(len(samples_train) + len(samples_val) + len(samples_test)):5}")
Train/Val/Test Summary by Events
Train events: 37097, (59.58%)
Val events:  6075, (9.76%)
Test events: 19090, (30.66%)
Total events: 62262
PRINT_ALL_CASE_SPLIT_DETAILS = False
case_to_sample_distribution = defaultdict(lambda: {'train': [0, 0], 'val': [0, 0], 'test': [0, 0]})
def populate_case_to_sample_distribution(mysamples, idx):
neg = 0
pos = 0
for fn, _ in mysamples:
(case, start_x, pred_window, label) = get_segment_attributes_from_filename(fn)
slot = 0 if label == 'False' else 1
case_to_sample_distribution[case][idx][slot] += 1
if slot == 0:
neg += 1
else:
pos += 1
return (neg, pos)
train_neg, train_pos = populate_case_to_sample_distribution(samples_train, 'train')
val_neg, val_pos = populate_case_to_sample_distribution(samples_val, 'val')
test_neg, test_pos = populate_case_to_sample_distribution(samples_test, 'test')
print(f'Total Cases Present: {len(case_to_sample_distribution):5}')
print()
train_tot = train_pos + train_neg
val_tot = val_pos + val_neg
test_tot = test_pos + test_neg
print(f'Train: P: {train_pos:5} ({(train_pos/train_tot):.2}), N: {train_neg:5} ({(train_neg/train_tot):.2})')
print(f'Val: P: {val_pos:5} ({(val_pos/val_tot):.2}), N: {val_neg:5} ({(val_neg/val_tot):.2})')
print(f'Test: P: {test_pos:5} ({(test_pos/test_tot):.2}), N: {test_neg:5} ({(test_neg/test_tot):.2})')
print()
total_pos = train_pos + val_pos + test_pos
total_neg = train_neg + val_neg + test_neg
total = total_pos + total_neg
print(f'P/N Ratio: {(total_pos)}:{(total_neg)}')
print(f'P Percent: {(total_pos/total):.2}')
print(f'N Percent: {(total_neg/total):.2}')
print()
if PRINT_ALL_CASE_SPLIT_DETAILS:
for ci in sorted(case_to_sample_distribution.keys()):
print(f'{ci}: {case_to_sample_distribution[ci]}')
Total Cases Present:  3085

Train: P: 15113 (0.41), N: 21984 (0.59)
Val:   P:  2535 (0.42), N:  3540 (0.58)
Test:  P:  7042 (0.37), N: 12048 (0.63)

P/N Ratio: 24690:37572
P Percent: 0.4
N Percent: 0.6
# Create vitalDataset class
class vitalDataset(Dataset):
def __init__(self, samples, normalize_abp=False):
self.samples = samples
self.normalize_abp = normalize_abp
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
# Get metadata for this event
segment = self.samples[idx]
file_path = segment[0]
label = (segment[1] == "True" or segment[1] == "True.vital")
(abp, ecg, eeg) = get_segment_data(file_path)
if abp is None or eeg is None or ecg is None:
# fall back to zero-filled waveforms and a negative label when a segment fails to load
return (np.zeros(30000), np.zeros(30000), np.zeros(7680), 0)
if self.normalize_abp:
abp -= 65
abp /= 65
return abp, ecg, eeg, label
NORMALIZE_ABP = False
train_dataset = vitalDataset(samples_train, NORMALIZE_ABP)
val_dataset = vitalDataset(samples_val, NORMALIZE_ABP)
test_dataset = vitalDataset(samples_test, NORMALIZE_ABP)
BATCH_SIZE = 64
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
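The optional `normalize_abp` step in `vitalDataset` recenters the waveform on the 65 mmHg hypotension threshold, so a mean arterial pressure of exactly 65 maps to 0 and other values become signed fractions of 65. A minimal sketch on a toy array:

```python
import numpy as np

abp = np.array([32.5, 65.0, 97.5, 130.0])  # toy ABP values in mmHg

# same transform as vitalDataset: (abp - 65) / 65
abp_norm = (abp - 65.0) / 65.0
print(abp_norm.tolist())  # → [-0.5, 0.0, 0.5, 1.0]
```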
Check whether the data can be classified using non-deep-learning methods. We create a balanced sample of IOH and non-IOH events and apply a simple classifier to see whether the classes are separable. A dataset that a simple classifier can separate should also be classifiable by a deep learning model.
MAX_CLASSIFICATION_SAMPLES = 250
MAX_SAMPLE_SIZE = 1600
classification_sample_size = min(MAX_SAMPLE_SIZE, len(samples))
classification_samples = random.sample(samples, classification_sample_size)
positive_samples = []
negative_samples = []
for sample in classification_samples:
(sampleAbp, sampleEcg, sampleEeg) = get_segment_data(sample[0])
if sample[1] == "True":
positive_samples.append([sample[0], True, sampleAbp, sampleEcg, sampleEeg])
else:
negative_samples.append([sample[0], False, sampleAbp, sampleEcg, sampleEeg])
positive_samples = pd.DataFrame(positive_samples, columns=["file_path", "segment_label", "segment_abp", "segment_ecg", "segment_eeg"])
negative_samples = pd.DataFrame(negative_samples, columns=["file_path", "segment_label", "segment_abp", "segment_ecg", "segment_eeg"])
total_to_sample_pos = min(MAX_CLASSIFICATION_SAMPLES, len(positive_samples))
total_to_sample_neg = min(MAX_CLASSIFICATION_SAMPLES, len(negative_samples))
# Select up to MAX_CLASSIFICATION_SAMPLES (250) random samples where segment_label is True
positive_samples = positive_samples.sample(total_to_sample_pos, random_state=RANDOM_SEED)
# Select up to MAX_CLASSIFICATION_SAMPLES (250) random samples where segment_label is False
negative_samples = negative_samples.sample(total_to_sample_neg, random_state=RANDOM_SEED)
print(f'positive_samples: {len(positive_samples)}')
print(f'negative_samples: {len(negative_samples)}')
# Combine the positive and negative samples
samples_balanced = pd.concat([positive_samples, negative_samples])
positive_samples: 250
negative_samples: 250
Define a function to build the data for this study. Each waveform channel can be enabled or disabled independently:
def get_x_y(samples, use_abp, use_ecg, use_eeg):
# Create X and y, using data from `samples_balanced` and the `use_abp`, `use_ecg`, and `use_eeg` variables
X = []
y = []
for i in range(len(samples)):
row = samples.iloc[i]
sample = np.array([])
if use_abp:
if len(row['segment_abp']) != 30000:
print(len(row['segment_abp']))
sample = np.append(sample, row['segment_abp'])
if use_ecg:
if len(row['segment_ecg']) != 30000:
print(len(row['segment_ecg']))
sample = np.append(sample, row['segment_ecg'])
if use_eeg:
if len(row['segment_eeg']) != 7680:
print(len(row['segment_eeg']))
sample = np.append(sample, row['segment_eeg'])
X.append(sample)
# Convert the label from boolean to 0 or 1
y.append(int(row['segment_label']))
return X, y
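Since `get_x_y` simply concatenates the enabled channels, the feature-vector length for each configuration follows directly from the segment shapes (ABP and ECG: 30,000 samples; EEG: 7,680 samples). A small check of the expected lengths:

```python
ABP_LEN, ECG_LEN, EEG_LEN = 30000, 30000, 7680

def feature_len(use_abp: bool, use_ecg: bool, use_eeg: bool) -> int:
    # length of the concatenated feature vector seen by the classifier
    return ABP_LEN * use_abp + ECG_LEN * use_ecg + EEG_LEN * use_eeg

print(feature_len(True, False, False))  # ABP only        → 30000
print(feature_len(True, False, True))   # ABP + EEG       → 37680
print(feature_len(True, True, True))    # ABP + ECG + EEG → 67680
```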
Define KNN run. This is configurable to enable or disable different data channels so that we can study them individually or together:
N_NEIGHBORS = 20
def run_knn(samples, use_abp, use_ecg, use_eeg):
# Get samples
X,y = get_x_y(samples, use_abp, use_ecg, use_eeg)
# Split samples into train and val
knn_X_train, knn_X_test, knn_y_train, knn_y_test = train_test_split(X, y, test_size=0.2, random_state=RANDOM_SEED)
# Normalize the data
scaler = StandardScaler()
scaler.fit(knn_X_train)
knn_X_train = scaler.transform(knn_X_train)
knn_X_test = scaler.transform(knn_X_test)
# Initialize the KNN classifier
knn = KNeighborsClassifier(n_neighbors=N_NEIGHBORS)
# Train the KNN classifier
knn.fit(knn_X_train, knn_y_train)
# Make predictions on the test set
knn_y_pred = knn.predict(knn_X_test)
# Evaluate the KNN classifier
print(f"ABP: {use_abp}, ECG: {use_ecg}, EEG: {use_eeg}")
print(f"Confusion matrix:\n{confusion_matrix(knn_y_test, knn_y_pred)}")
print(f"Classification report:\n{classification_report(knn_y_test, knn_y_pred)}")
Study each waveform independently, then ABP+EEG (which had best results in paper), and ABP+ECG+EEG:
run_knn(samples_balanced, use_abp=True, use_ecg=False, use_eeg=False)
run_knn(samples_balanced, use_abp=False, use_ecg=True, use_eeg=False)
run_knn(samples_balanced, use_abp=False, use_ecg=False, use_eeg=True)
run_knn(samples_balanced, use_abp=True, use_ecg=False, use_eeg=True)
run_knn(samples_balanced, use_abp=True, use_ecg=True, use_eeg=True)
ABP: True, ECG: False, EEG: False
Confusion matrix:
[[38 16]
[10 36]]
Classification report:
precision recall f1-score support
0 0.79 0.70 0.75 54
1 0.69 0.78 0.73 46
accuracy 0.74 100
macro avg 0.74 0.74 0.74 100
weighted avg 0.75 0.74 0.74 100
ABP: False, ECG: True, EEG: False
Confusion matrix:
[[52 2]
[46 0]]
Classification report:
precision recall f1-score support
0 0.53 0.96 0.68 54
1 0.00 0.00 0.00 46
accuracy 0.52 100
macro avg 0.27 0.48 0.34 100
weighted avg 0.29 0.52 0.37 100
ABP: False, ECG: False, EEG: True
Confusion matrix:
[[ 4 50]
[ 6 40]]
Classification report:
precision recall f1-score support
0 0.40 0.07 0.12 54
1 0.44 0.87 0.59 46
accuracy 0.44 100
macro avg 0.42 0.47 0.36 100
weighted avg 0.42 0.44 0.34 100
ABP: True, ECG: False, EEG: True
Confusion matrix:
[[42 12]
[15 31]]
Classification report:
precision recall f1-score support
0 0.74 0.78 0.76 54
1 0.72 0.67 0.70 46
accuracy 0.73 100
macro avg 0.73 0.73 0.73 100
weighted avg 0.73 0.73 0.73 100
ABP: True, ECG: True, EEG: True
Confusion matrix:
[[42 12]
[17 29]]
Classification report:
precision recall f1-score support
0 0.71 0.78 0.74 54
1 0.71 0.63 0.67 46
accuracy 0.71 100
macro avg 0.71 0.70 0.71 100
weighted avg 0.71 0.71 0.71 100
Based on the results above, the ABP data alone is moderately predictive, with a macro-average F1-score of 0.74. The ECG and EEG data are not predictive on their own, with macro-average F1-scores of 0.34 and 0.36, respectively. ABP+EEG performs comparably to ABP alone (macro-average F1 of 0.73), and ABP+ECG+EEG slightly worse (0.71).
Models based on ABP alone or ABP+EEG are therefore expected to train readily with good performance; ECG and EEG appear to contribute mostly noise when added. This agrees with the results from the paper.
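As a cross-check on these metrics, the macro-average F1 for the ABP-only run can be recomputed directly from its confusion matrix printed above ([[38, 16], [10, 36]]):

```python
import numpy as np

cm = np.array([[38, 16],
               [10, 36]])  # rows: true class, columns: predicted class

f1s = []
for c in (0, 1):
    precision = cm[c, c] / cm[:, c].sum()
    recall = cm[c, c] / cm[c, :].sum()
    f1s.append(2 * precision * recall / (precision + recall))

print(round(float(np.mean(f1s)), 2))  # → 0.74
```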
Define t-SNE run. This is configurable to enable or disable different data channels so that we can study them individually or together:
def run_tsne(samples, use_abp, use_ecg, use_eeg):
# Get samples
X,y = get_x_y(samples, use_abp, use_ecg, use_eeg)
# Convert X and y to numpy arrays
X = np.array(X)
y = np.array(y)
# Embed the samples into 2 dimensions with t-SNE for plotting
tsne = TSNE(n_components=2, random_state=RANDOM_SEED)
X_tsne = tsne.fit_transform(X)
# Create a scatter plot of the t-SNE representation
plt.figure(figsize=(16, 9))
plt.title(f"use_abp={use_abp}, use_ecg={use_ecg}, use_eeg={use_eeg}")
for i, label in enumerate(set(y)):
plt.scatter(X_tsne[y == label, 0], X_tsne[y == label, 1], label=label)
plt.legend()
plt.show()
Study each waveform independently, then ABP+EEG (which had best results in paper), and ABP+ECG+EEG:
run_tsne(samples_balanced, use_abp=True, use_ecg=False, use_eeg=False)
run_tsne(samples_balanced, use_abp=False, use_ecg=True, use_eeg=False)
run_tsne(samples_balanced, use_abp=False, use_ecg=False, use_eeg=True)
run_tsne(samples_balanced, use_abp=True, use_ecg=False, use_eeg=True)
run_tsne(samples_balanced, use_abp=True, use_ecg=True, use_eeg=True)
Based on the plots above, ABP alone, ABP+EEG, and ABP+ECG+EEG appear somewhat separable, albeit with outliers, and should be learnable by our model. The classes are not readily separable using ECG or EEG alone. This agrees with the results from the paper.
# cleanup
samples_balanced = None
The model implementation is based on the CNN architecture described in Jo Y-Y et al. (2022). It is designed to handle 1, 2, or 3 signal categories simultaneously, allowing for flexible model configurations based on different combinations of physiological signals:
The architecture, as depicted in Figure 2 from the original paper, utilizes a ResNet-based approach tailored for time-series data from different physiological signals. The model architecture is adapted to handle varying input signal frequencies, with specific hyperparameters for each signal type, particularly EEG, due to its distinct characteristics compared to ABP and ECG. A diagram of the model architecture is shown below:
Each input signal is processed through a sequence of twelve seven-layer residual blocks, followed by flattening and a linear transformation that produces a 32-dimensional feature vector per signal type. These vectors are concatenated (when multiple signals are used) and passed through two additional linear layers to produce a single output value, the IOH index. A threshold, determined experimentally to minimize the difference between sensitivity and specificity, is applied to this index to yield a binary prediction of IOH events.
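The threshold selection described above can be sketched as a sweep over candidate cutoffs, keeping the one where sensitivity and specificity are closest. The scores and labels below are illustrative, not model outputs:

```python
import numpy as np

def pick_threshold(scores, labels):
    # choose the cutoff that minimizes |sensitivity - specificity|
    best_t, best_gap = 0.5, float('inf')
    for t in np.linspace(0.01, 0.99, 99):
        pred = scores >= t
        pos, neg = labels == 1, labels == 0
        sensitivity = (pred & pos).sum() / pos.sum()
        specificity = (~pred & neg).sum() / neg.sum()
        gap = abs(sensitivity - specificity)
        if gap < best_gap:
            best_t, best_gap = t, gap
    return best_t

# illustrative scores: positives tend to score higher than negatives
rng = np.random.default_rng(0)
scores = np.concatenate([rng.uniform(0.4, 1.0, 100), rng.uniform(0.0, 0.6, 100)])
labels = np.concatenate([np.ones(100), np.zeros(100)]).astype(int)
print(pick_threshold(scores, labels))
```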
The hyperparameters for the residual blocks are specified in Supplemental Table 1 of the original paper and vary by signal type.
A forward pass for each signal traverses 85 layers before concatenation (12 residual blocks of 7 layers each, plus the per-signal linear layer), followed by two more linear layers and a final sigmoid activation to produce the prediction score.
Each residual block consists of the following seven layers: batch normalization, ReLU activation, dropout (p = 0.5), 1D convolution, batch normalization, ReLU activation, and a second 1D convolution. Blocks that halve the sequence length additionally apply max pooling after the first convolution.
Skip connections are included to aid in gradient flow during training, with optional 1D convolution in the skip connection to align dimensions.
The hyperparameters are detailed in Supplemental Table 1 of the original paper. A screenshot of these hyperparameters is provided for reference below:

Note: there is a transcription error in the original paper's Supplemental Table 1 for the ECG+ABP configuration in Residual Blocks 11 and 12: the output size should be 469 × 6 (since 938 / 2 = 469), not the reported 496 × 6.
Our model uses binary cross entropy as the loss function and Adam as the optimizer, consistent with the original study. The learning rate is set at 0.0001, and training is configured to run for up to 100 epochs, with early stopping implemented if no improvement in loss is observed over five consecutive epochs.
# First define the residual block which is reused 12x for each data track for each sample.
# Second define the primary model.
class ResidualBlock(nn.Module):
def __init__(self, in_features: int, out_features: int, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, size_down: bool = False, ignoreSkipConnection: bool = False) -> None:
super(ResidualBlock, self).__init__()
self.ignoreSkipConnection = ignoreSkipConnection
# calculate the appropriate padding required to ensure expected sequence lengths out of each residual block
padding = int((((stride-1)*in_features)-stride+kernel_size)/2)
self.size_down = size_down
self.bn1 = nn.BatchNorm1d(in_channels)
self.relu = nn.ReLU()
self.dropout = nn.Dropout(0.5)
self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding, bias=False)
self.bn2 = nn.BatchNorm1d(out_channels)
self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding, bias=False)
self.residualConv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding, bias=False)
# the paper does not specify where in the block downsampling occurs; Supplemental Table S1 only shows that these blocks halve the sequence length
if self.size_down:
pool_padding = (1 if (in_features % 2 > 0) else 0)
self.downsample = nn.MaxPool1d(kernel_size=2, stride=2, padding = pool_padding)
def forward(self, x: torch.Tensor) -> torch.Tensor:
identity = x
out = self.bn1(x)
out = self.relu(out)
out = self.dropout(out)
out = self.conv1(out)
if self.size_down:
out = self.downsample(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv2(out)
if not self.ignoreSkipConnection:
if out.shape != identity.shape:
# run the residual through a convolution when necessary
identity = self.residualConv(identity)
outlen = np.prod(out.shape)
idlen = np.prod(identity.shape)
# downsample when required
if idlen > outlen:
identity = self.downsample(identity)
# match dimensions
identity = identity.reshape(out.shape)
# add the residual
out += identity
return out
class HypotensionCNN(nn.Module):
def __init__(self, useAbp: bool = True, useEeg: bool = False, useEcg: bool = False, maxSixResiduals: bool = False, maxOneResiduals: bool = False, ignoreSkipConnection: bool = False) -> None:
super(HypotensionCNN, self).__init__()
self.useAbp = useAbp
self.useEeg = useEeg
self.useEcg = useEcg
self.maxSixResiduals = maxSixResiduals
self.maxOneResiduals= maxOneResiduals
if useAbp:
if not self.maxOneResiduals and not self.maxSixResiduals:
self.abpBlock1 = ResidualBlock(30000, 15000, 1, 2, 15, 1, True, ignoreSkipConnection)
self.abpBlock2 = ResidualBlock(15000, 15000, 2, 2, 15, 1, False, ignoreSkipConnection)
self.abpBlock3 = ResidualBlock(15000, 7500, 2, 2, 15, 1, True, ignoreSkipConnection)
self.abpBlock4 = ResidualBlock(7500, 7500, 2, 2, 15, 1, False, ignoreSkipConnection)
self.abpBlock5 = ResidualBlock(7500, 3750, 2, 2, 15, 1, True, ignoreSkipConnection)
self.abpBlock6 = ResidualBlock(3750, 3750, 2, 4, 15, 1, False, ignoreSkipConnection)
self.abpBlock7 = ResidualBlock(3750, 1875, 4, 4, 7, 1, True, ignoreSkipConnection)
self.abpBlock8 = ResidualBlock(1875, 1875, 4, 4, 7, 1, False, ignoreSkipConnection)
self.abpBlock9 = ResidualBlock(1875, 938, 4, 4, 7, 1, True, ignoreSkipConnection)
self.abpBlock10 = ResidualBlock(938, 938, 4, 4, 7, 1, False, ignoreSkipConnection)
self.abpBlock11 = ResidualBlock(938, 469, 4, 6, 7, 1, True, ignoreSkipConnection)
self.abpBlock12 = ResidualBlock(469, 469, 6, 6, 7, 1, False, ignoreSkipConnection)
self.abpFc = nn.Linear(6*469, 32)
elif self.maxOneResiduals:
self.abpBlock1 = ResidualBlock(30000, 15000, 1, 2, 15, 1, True, ignoreSkipConnection)
self.abpFc = nn.Linear(2 * 15000, 32)
elif self.maxSixResiduals:
self.abpBlock1 = ResidualBlock(30000, 15000, 1, 2, 15, 1, True, ignoreSkipConnection)
self.abpBlock2 = ResidualBlock(15000, 15000, 2, 2, 15, 1, False, ignoreSkipConnection)
self.abpBlock3 = ResidualBlock(15000, 7500, 2, 2, 15, 1, True, ignoreSkipConnection)
self.abpBlock4 = ResidualBlock(7500, 7500, 2, 2, 15, 1, False, ignoreSkipConnection)
self.abpBlock5 = ResidualBlock(7500, 3750, 2, 2, 15, 1, True, ignoreSkipConnection)
self.abpBlock6 = ResidualBlock(3750, 3750, 2, 4, 15, 1, False, ignoreSkipConnection)
self.abpFc = nn.Linear(4 * 3750, 32)
if useEcg:
if not self.maxOneResiduals and not self.maxSixResiduals:
self.ecgBlock1 = ResidualBlock(30000, 15000, 1, 2, 15, 1, True, ignoreSkipConnection)
self.ecgBlock2 = ResidualBlock(15000, 15000, 2, 2, 15, 1, False, ignoreSkipConnection)
self.ecgBlock3 = ResidualBlock(15000, 7500, 2, 2, 15, 1, True, ignoreSkipConnection)
self.ecgBlock4 = ResidualBlock(7500, 7500, 2, 2, 15, 1, False, ignoreSkipConnection)
self.ecgBlock5 = ResidualBlock(7500, 3750, 2, 2, 15, 1, True, ignoreSkipConnection)
self.ecgBlock6 = ResidualBlock(3750, 3750, 2, 4, 15, 1, False, ignoreSkipConnection)
self.ecgBlock7 = ResidualBlock(3750, 1875, 4, 4, 7, 1, True, ignoreSkipConnection)
self.ecgBlock8 = ResidualBlock(1875, 1875, 4, 4, 7, 1, False, ignoreSkipConnection)
self.ecgBlock9 = ResidualBlock(1875, 938, 4, 4, 7, 1, True, ignoreSkipConnection)
self.ecgBlock10 = ResidualBlock(938, 938, 4, 4, 7, 1, False, ignoreSkipConnection)
self.ecgBlock11 = ResidualBlock(938, 469, 4, 6, 7, 1, True, ignoreSkipConnection)
self.ecgBlock12 = ResidualBlock(469, 469, 6, 6, 7, 1, False, ignoreSkipConnection)
self.ecgFc = nn.Linear(6 * 469, 32)
elif self.maxOneResiduals:
self.ecgBlock1 = ResidualBlock(30000, 15000, 1, 2, 15, 1, True, ignoreSkipConnection)
self.ecgFc = nn.Linear(2 * 15000, 32)
elif self.maxSixResiduals:
self.ecgBlock1 = ResidualBlock(30000, 15000, 1, 2, 15, 1, True, ignoreSkipConnection)
self.ecgBlock2 = ResidualBlock(15000, 15000, 2, 2, 15, 1, False, ignoreSkipConnection)
self.ecgBlock3 = ResidualBlock(15000, 7500, 2, 2, 15, 1, True, ignoreSkipConnection)
self.ecgBlock4 = ResidualBlock(7500, 7500, 2, 2, 15, 1, False, ignoreSkipConnection)
self.ecgBlock5 = ResidualBlock(7500, 3750, 2, 2, 15, 1, True, ignoreSkipConnection)
self.ecgBlock6 = ResidualBlock(3750, 3750, 2, 4, 15, 1, False, ignoreSkipConnection)
self.ecgFc = nn.Linear(4 * 3750, 32)
if useEeg:
if not self.maxOneResiduals and not self.maxSixResiduals:
self.eegBlock1 = ResidualBlock(7680, 3840, 1, 2, 7, 1, True, ignoreSkipConnection)
self.eegBlock2 = ResidualBlock(3840, 3840, 2, 2, 7, 1, False, ignoreSkipConnection)
self.eegBlock3 = ResidualBlock(3840, 1920, 2, 2, 7, 1, True, ignoreSkipConnection)
self.eegBlock4 = ResidualBlock(1920, 1920, 2, 2, 7, 1, False, ignoreSkipConnection)
self.eegBlock5 = ResidualBlock(1920, 960, 2, 2, 7, 1, True, ignoreSkipConnection)
self.eegBlock6 = ResidualBlock(960, 960, 2, 4, 7, 1, False, ignoreSkipConnection)
self.eegBlock7 = ResidualBlock(960, 480, 4, 4, 3, 1, True, ignoreSkipConnection)
self.eegBlock8 = ResidualBlock(480, 480, 4, 4, 3, 1, False, ignoreSkipConnection)
self.eegBlock9 = ResidualBlock(480, 240, 4, 4, 3, 1, True, ignoreSkipConnection)
self.eegBlock10 = ResidualBlock(240, 240, 4, 4, 3, 1, False, ignoreSkipConnection)
self.eegBlock11 = ResidualBlock(240, 120, 4, 6, 3, 1, True, ignoreSkipConnection)
self.eegBlock12 = ResidualBlock(120, 120, 6, 6, 3, 1, False, ignoreSkipConnection)
self.eegFc = nn.Linear(6 * 120, 32)
elif self.maxOneResiduals:
self.eegBlock1 = ResidualBlock(7680, 3840, 1, 2, 7, 1, True, ignoreSkipConnection)
self.eegFc = nn.Linear(2 * 3840, 32)
elif self.maxSixResiduals:
self.eegBlock1 = ResidualBlock(7680, 3840, 1, 2, 7, 1, True, ignoreSkipConnection)
self.eegBlock2 = ResidualBlock(3840, 3840, 2, 2, 7, 1, False, ignoreSkipConnection)
self.eegBlock3 = ResidualBlock(3840, 1920, 2, 2, 7, 1, True, ignoreSkipConnection)
self.eegBlock4 = ResidualBlock(1920, 1920, 2, 2, 7, 1, False, ignoreSkipConnection)
self.eegBlock5 = ResidualBlock(1920, 960, 2, 2, 7, 1, True, ignoreSkipConnection)
self.eegBlock6 = ResidualBlock(960, 960, 2, 4, 7, 1, False, ignoreSkipConnection)
self.eegFc = nn.Linear(4 * 960, 32)
concatSize = 0
if useAbp:
concatSize += 32
if useEeg:
concatSize += 32
if useEcg:
concatSize += 32
self.fullLinear1 = nn.Linear(concatSize, 16)
self.fullLinear2 = nn.Linear(16, 1)
self.sigmoid = nn.Sigmoid()
def forward(self, abp: torch.Tensor, eeg: torch.Tensor, ecg: torch.Tensor) -> torch.Tensor:
batchSize = len(abp)
# conditionally operate ABP, EEG, and ECG networks
if self.useAbp:
if self.maxOneResiduals:
abp = self.abpBlock1(abp)
elif self.maxSixResiduals:
abp = self.abpBlock1(abp)
abp = self.abpBlock2(abp)
abp = self.abpBlock3(abp)
abp = self.abpBlock4(abp)
abp = self.abpBlock5(abp)
abp = self.abpBlock6(abp)
elif not self.maxOneResiduals and not self.maxSixResiduals:
abp = self.abpBlock1(abp)
abp = self.abpBlock2(abp)
abp = self.abpBlock3(abp)
abp = self.abpBlock4(abp)
abp = self.abpBlock5(abp)
abp = self.abpBlock6(abp)
abp = self.abpBlock7(abp)
abp = self.abpBlock8(abp)
abp = self.abpBlock9(abp)
abp = self.abpBlock10(abp)
abp = self.abpBlock11(abp)
abp = self.abpBlock12(abp)
totalLen = np.prod(abp.shape)
abp = torch.reshape(abp, (batchSize, int(totalLen / batchSize)))
abp = self.abpFc(abp)
if self.useEeg:
if self.maxOneResiduals:
eeg = self.eegBlock1(eeg)
elif self.maxSixResiduals:
eeg = self.eegBlock1(eeg)
eeg = self.eegBlock2(eeg)
eeg = self.eegBlock3(eeg)
eeg = self.eegBlock4(eeg)
eeg = self.eegBlock5(eeg)
eeg = self.eegBlock6(eeg)
elif not self.maxOneResiduals and not self.maxSixResiduals:
eeg = self.eegBlock1(eeg)
eeg = self.eegBlock2(eeg)
eeg = self.eegBlock3(eeg)
eeg = self.eegBlock4(eeg)
eeg = self.eegBlock5(eeg)
eeg = self.eegBlock6(eeg)
eeg = self.eegBlock7(eeg)
eeg = self.eegBlock8(eeg)
eeg = self.eegBlock9(eeg)
eeg = self.eegBlock10(eeg)
eeg = self.eegBlock11(eeg)
eeg = self.eegBlock12(eeg)
totalLen = np.prod(eeg.shape)
eeg = torch.reshape(eeg, (batchSize, int(totalLen / batchSize)))
eeg = self.eegFc(eeg)
if self.useEcg:
if self.maxOneResiduals:
ecg = self.ecgBlock1(ecg)
elif self.maxSixResiduals:
ecg = self.ecgBlock1(ecg)
ecg = self.ecgBlock2(ecg)
ecg = self.ecgBlock3(ecg)
ecg = self.ecgBlock4(ecg)
ecg = self.ecgBlock5(ecg)
ecg = self.ecgBlock6(ecg)
elif not self.maxOneResiduals and not self.maxSixResiduals:
ecg = self.ecgBlock1(ecg)
ecg = self.ecgBlock2(ecg)
ecg = self.ecgBlock3(ecg)
ecg = self.ecgBlock4(ecg)
ecg = self.ecgBlock5(ecg)
ecg = self.ecgBlock6(ecg)
ecg = self.ecgBlock7(ecg)
ecg = self.ecgBlock8(ecg)
ecg = self.ecgBlock9(ecg)
ecg = self.ecgBlock10(ecg)
ecg = self.ecgBlock11(ecg)
ecg = self.ecgBlock12(ecg)
totalLen = np.prod(ecg.shape)
ecg = torch.reshape(ecg, (batchSize, int(totalLen / batchSize)))
ecg = self.ecgFc(ecg)
# concatenation
merged = None
if self.useAbp and self.useEeg and self.useEcg:
merged = torch.cat((abp, eeg, ecg), dim=1)
elif self.useAbp and self.useEeg:
merged = torch.cat((abp, eeg), dim=1)
elif self.useAbp and self.useEcg:
merged = torch.cat((abp, ecg), dim=1)
elif self.useEeg and self.useEcg:
merged = torch.cat((eeg, ecg), dim=1)
elif self.useAbp:
merged = abp
elif self.useEeg:
merged = eeg
elif self.useEcg:
merged = ecg
totalLen = np.prod(merged.shape)
merged = torch.reshape(merged, (batchSize, int(totalLen / batchSize)))
out = self.fullLinear1(merged)
out = self.fullLinear2(out)
out = self.sigmoid(out)
out = torch.nan_to_num(out)
return out
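The padding expression in `ResidualBlock.__init__` reduces to (kernel_size − 1) / 2 when stride = 1, i.e. "same" padding, so each convolution preserves the sequence length and only the max-pooling steps halve it. A quick check against the standard Conv1d output-length formula, using feature lengths and kernel sizes taken from the blocks above:

```python
def conv1d_out_len(L, kernel_size, stride, padding):
    # standard Conv1d output-length formula (dilation = 1)
    return (L + 2 * padding - kernel_size) // stride + 1

for L, k in [(30000, 15), (7680, 7), (938, 7), (120, 3)]:
    stride = 1
    padding = int((((stride - 1) * L) - stride + k) / 2)  # formula from ResidualBlock
    assert padding == (k - 1) // 2                        # 'same' padding for odd kernels
    assert conv1d_out_len(L, k, stride, padding) == L
print("padding preserves sequence length for all block configurations")
```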
As discussed earlier, our model uses binary cross entropy as the loss function and Adam as the optimizer, consistent with the original study. The learning rate is set at 0.0001, and training runs for up to 100 epochs with early stopping. Note that while the original study stopped after five epochs without improvement, we use a more lenient patience of 15 epochs below.
LEARNING_RATE = 0.0001
PATIENCE = 15
useAbp = True
useEeg = False
useEcg = False
# enable only a single ablation
useAblationSixResidualBlocks = False
useAblationOneResidualBlocks = False
useAblationIgnoreSkipConnection = False
# to be composed by checking config booleans
experimentName = "DEFAULT"
# enforce single ablation
if useAblationSixResidualBlocks and useAblationOneResidualBlocks and useAblationIgnoreSkipConnection:
# if all 3 selected, only choose one residual block
useAblationSixResidualBlocks = False
useAblationIgnoreSkipConnection = False
elif useAblationSixResidualBlocks and useAblationOneResidualBlocks:
# if 6 and 1, only choose 1
useAblationSixResidualBlocks = False
elif useAblationSixResidualBlocks and useAblationIgnoreSkipConnection:
# if six and skip, only choose six
useAblationIgnoreSkipConnection = False
elif useAblationOneResidualBlocks and useAblationIgnoreSkipConnection:
# if one and skip, only choose one
useAblationIgnoreSkipConnection = False
if useAbp and useEeg and useEcg:
experimentName = "ABP_EEG_ECG"
elif useAbp and useEeg:
experimentName = "ABP_EEG"
elif useAbp and useEcg:
experimentName = "ABP_ECG"
elif useEeg and useEcg:
experimentName = "EEG_ECG"
elif useAbp:
experimentName = "ABP"
elif useEeg:
experimentName = "EEG"
elif useEcg:
experimentName = "ECG"
if useAblationSixResidualBlocks:
experimentName = f"{experimentName}_ABLATION_SIX_RESIDUAL_BLOCKS"
if useAblationOneResidualBlocks:
experimentName = f"{experimentName}_ABLATION_ONE_RESIDUAL_BLOCK"
if useAblationIgnoreSkipConnection:
experimentName = f"{experimentName}_ABLATION_IGNORE_SKIP_CONNECTION"
experimentName = f"{experimentName}_{PREDICTION_WINDOW}_MINS"
if MAX_CASES is not None:
experimentName = f"{experimentName}_MAX_{MAX_CASES}_CASES"
print(f"Preparing to run experiment titled {experimentName}")
model = HypotensionCNN(useAbp, useEeg, useEcg, useAblationSixResidualBlocks, useAblationOneResidualBlocks, useAblationIgnoreSkipConnection)
loss_func = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if (torch.backends.mps.is_available() and torch.backends.mps.is_built()) else "cpu")
print(f"Using device: {device}")
model = model.to(device)
def train_model_one_iter(model, loss_func, optimizer, train_loader):
    model.train()
    train_losses = []
    for abp, ecg, eeg, label in tqdm(train_loader):
        batch = len(abp)
        abp = abp.reshape(batch, 1, -1).type(torch.FloatTensor).to(device)
        ecg = ecg.reshape(batch, 1, -1).type(torch.FloatTensor).to(device)
        eeg = eeg.reshape(batch, 1, -1).type(torch.FloatTensor).to(device)
        label = label.type(torch.float).reshape(batch, 1).to(device)
        optimizer.zero_grad()
        mdl = model(abp, eeg, ecg)
        loss = loss_func(torch.nan_to_num(mdl), label)
        loss.backward()
        optimizer.step()
        train_losses.append(loss.detach().cpu().numpy())
    return np.mean(train_losses)
def evaluate_model(model, loss_func, val_loader):
    model.eval()
    val_losses = []
    with torch.no_grad():  # no gradients needed during validation
        for abp, ecg, eeg, label in tqdm(val_loader):
            batch = len(abp)
            abp = abp.reshape(batch, 1, -1).type(torch.FloatTensor).to(device)
            ecg = ecg.reshape(batch, 1, -1).type(torch.FloatTensor).to(device)
            eeg = eeg.reshape(batch, 1, -1).type(torch.FloatTensor).to(device)
            label = label.type(torch.float).reshape(batch, 1).to(device)
            mdl = model(abp, eeg, ecg)
            loss = loss_func(torch.nan_to_num(mdl), label)
            val_losses.append(loss.cpu().numpy())
    return np.mean(val_losses)
# Training loop
max_epochs = 100
best_epoch = 0
train_losses = []
val_losses = []
best_loss = float('inf')
no_improve_epochs = 0
model_path = os.path.join(VITAL_MODELS, f"{experimentName}.model")
all_models = []
for i in range(max_epochs):
    # Train the model and get the training loss
    train_loss = train_model_one_iter(model, loss_func, optimizer, train_loader)
    train_losses.append(train_loss)
    # Calculate the validation loss
    val_loss = evaluate_model(model, loss_func, val_loader)
    val_losses.append(val_loss)
    print(f"[{datetime.now()}] Completed epoch {i} with training loss {train_loss:.8f}, validation loss {val_loss:.8f}")
    # Save all intermediary models.
    tmp_model_path = os.path.join(VITAL_MODELS, f"{experimentName}_{i:04d}.model")
    torch.save(model.state_dict(), tmp_model_path)
    all_models.append(tmp_model_path)
    # Check if validation loss has improved
    if val_loss < best_loss:
        best_epoch = i
        best_loss = val_loss
        no_improve_epochs = 0
        torch.save(model.state_dict(), model_path)
        print(f"Validation loss improved to {val_loss:.8f}. Model saved.")
    else:
        no_improve_epochs += 1
        print(f"No improvement in validation loss. {no_improve_epochs} epochs without improvement.")
        # exit early if no improvement in loss over the last PATIENCE epochs
        if no_improve_epochs >= PATIENCE:
            print("Early stopping due to no improvement in validation loss.")
            break
# Load best model from disk
if os.path.exists(model_path):
    model.load_state_dict(torch.load(model_path))
    print(f"Loaded best model from disk from epoch {best_epoch}.")
else:
    print(f"No saved model found for {experimentName}.")
model.train(False)
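The patience-based early stopping that drives the loop above can be isolated into a small pure-Python sketch. Here `run_early_stopping` is a hypothetical helper (the real loop also checkpoints the model to disk); it only illustrates how the best epoch and the stop epoch fall out of a sequence of validation losses:

```python
def run_early_stopping(val_losses, patience):
    """Return (best_epoch, stop_epoch) for a sequence of validation
    losses, mirroring the patience counter used in the training loop."""
    best_loss = float("inf")
    best_epoch = 0
    no_improve = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            # strict improvement: record it and reset the counter
            best_loss, best_epoch, no_improve = loss, epoch, 0
        else:
            no_improve += 1
            if no_improve >= patience:
                return best_epoch, epoch  # stopped early
    return best_epoch, len(val_losses) - 1  # ran to completion

# losses improve at epochs 0 and 3, then stall; with patience=3 the loop
# exhausts the list and keeps epoch 3 as best
print(run_early_stopping([0.64, 0.65, 0.66, 0.62, 0.63, 0.64], patience=3))  # (3, 5)
```

With `patience=2` the same sequence would stop at epoch 2 before ever seeing the epoch-3 improvement, which is why the notebook's `PATIENCE` setting trades compute for robustness to noisy validation loss.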
Preparing to run experiment titled ABP_ALL_MINS
Using device: mps
100%|█████████████████████████████████████████████████████████████████████████████████████| 580/580 [02:28<00:00, 3.91it/s] 100%|███████████████████████████████████████████████████████████████████████████████████████| 95/95 [00:19<00:00, 4.92it/s]
[2024-04-29 02:15:20.357495] Completed epoch 0 with training loss 0.60090095, validation loss 0.63873434 Validation loss improved to 0.63873434. Model saved.
[2024-04-29 02:18:04.742882] Completed epoch 1 with training loss 0.59035784, validation loss 0.64428014 No improvement in validation loss. 1 epochs without improvement.
[2024-04-29 02:20:49.455103] Completed epoch 2 with training loss 0.58876944, validation loss 0.65615457 No improvement in validation loss. 2 epochs without improvement.
[2024-04-29 02:23:35.021649] Completed epoch 3 with training loss 0.58622575, validation loss 0.61752570 Validation loss improved to 0.61752570. Model saved.
[2024-04-29 02:26:20.673648] Completed epoch 4 with training loss 0.58470070, validation loss 0.63259184 No improvement in validation loss. 1 epochs without improvement.
[2024-04-29 02:29:07.004208] Completed epoch 5 with training loss 0.58424848, validation loss 0.62557769 No improvement in validation loss. 2 epochs without improvement.
[2024-04-29 02:31:50.617867] Completed epoch 6 with training loss 0.58467871, validation loss 0.61844271 No improvement in validation loss. 3 epochs without improvement.
[2024-04-29 02:34:32.726852] Completed epoch 7 with training loss 0.58348274, validation loss 0.63179791 No improvement in validation loss. 4 epochs without improvement.
[2024-04-29 02:37:16.153784] Completed epoch 8 with training loss 0.58116812, validation loss 0.65234005 No improvement in validation loss. 5 epochs without improvement.
[2024-04-29 02:39:59.758599] Completed epoch 9 with training loss 0.58109468, validation loss 0.68132764 No improvement in validation loss. 6 epochs without improvement.
[2024-04-29 02:42:43.318282] Completed epoch 10 with training loss 0.58063072, validation loss 0.62297595 No improvement in validation loss. 7 epochs without improvement.
[2024-04-29 02:45:26.894458] Completed epoch 11 with training loss 0.58049780, validation loss 0.61584508 Validation loss improved to 0.61584508. Model saved.
[2024-04-29 02:48:10.341834] Completed epoch 12 with training loss 0.57950145, validation loss 0.67962462 No improvement in validation loss. 1 epochs without improvement.
[2024-04-29 02:50:53.800560] Completed epoch 13 with training loss 0.57805961, validation loss 0.66277754 No improvement in validation loss. 2 epochs without improvement.
[2024-04-29 02:53:37.292114] Completed epoch 14 with training loss 0.57829326, validation loss 0.62205291 No improvement in validation loss. 3 epochs without improvement.
[2024-04-29 02:56:20.826870] Completed epoch 15 with training loss 0.57719731, validation loss 0.66759384 No improvement in validation loss. 4 epochs without improvement.
[2024-04-29 02:59:04.075479] Completed epoch 16 with training loss 0.57676178, validation loss 0.61395967 Validation loss improved to 0.61395967. Model saved.
[2024-04-29 03:01:47.509303] Completed epoch 17 with training loss 0.57576907, validation loss 0.61815673 No improvement in validation loss. 1 epochs without improvement.
[2024-04-29 03:04:30.759211] Completed epoch 18 with training loss 0.57561922, validation loss 0.61415452 No improvement in validation loss. 2 epochs without improvement.
[2024-04-29 03:07:14.058624] Completed epoch 19 with training loss 0.57508308, validation loss 0.62138212 No improvement in validation loss. 3 epochs without improvement.
[2024-04-29 03:09:57.098606] Completed epoch 20 with training loss 0.57473373, validation loss 0.61876047 No improvement in validation loss. 4 epochs without improvement.
[2024-04-29 03:12:40.383532] Completed epoch 21 with training loss 0.57446361, validation loss 0.62025529 No improvement in validation loss. 5 epochs without improvement.
[2024-04-29 03:15:23.834245] Completed epoch 22 with training loss 0.57377350, validation loss 0.64310312 No improvement in validation loss. 6 epochs without improvement.
[2024-04-29 03:18:07.978313] Completed epoch 23 with training loss 0.57257605, validation loss 0.64022654 No improvement in validation loss. 7 epochs without improvement.
[2024-04-29 03:20:51.187422] Completed epoch 24 with training loss 0.57202220, validation loss 0.63490897 No improvement in validation loss. 8 epochs without improvement.
[2024-04-29 03:23:34.438348] Completed epoch 25 with training loss 0.57152414, validation loss 0.64263254 No improvement in validation loss. 9 epochs without improvement.
[2024-04-29 03:26:17.645174] Completed epoch 26 with training loss 0.57142615, validation loss 0.67803663 No improvement in validation loss. 10 epochs without improvement.
[2024-04-29 03:29:00.730093] Completed epoch 27 with training loss 0.56990808, validation loss 0.61330938 Validation loss improved to 0.61330938. Model saved.
[2024-04-29 03:31:43.955264] Completed epoch 28 with training loss 0.57021445, validation loss 0.66097891 No improvement in validation loss. 1 epochs without improvement.
[2024-04-29 03:34:26.604996] Completed epoch 29 with training loss 0.56995690, validation loss 0.61245501 Validation loss improved to 0.61245501. Model saved.
[2024-04-29 03:37:09.422634] Completed epoch 30 with training loss 0.56896681, validation loss 0.63970363 No improvement in validation loss. 1 epochs without improvement.
[2024-04-29 03:39:52.722066] Completed epoch 31 with training loss 0.56804359, validation loss 0.65943033 No improvement in validation loss. 2 epochs without improvement.
[2024-04-29 03:42:36.254604] Completed epoch 32 with training loss 0.56741714, validation loss 0.63177902 No improvement in validation loss. 3 epochs without improvement.
[2024-04-29 03:45:19.545220] Completed epoch 33 with training loss 0.56646210, validation loss 0.64889473 No improvement in validation loss. 4 epochs without improvement.
[2024-04-29 03:48:02.703213] Completed epoch 34 with training loss 0.56666028, validation loss 0.63889188 No improvement in validation loss. 5 epochs without improvement.
[2024-04-29 03:50:46.044629] Completed epoch 35 with training loss 0.56519240, validation loss 0.63445991 No improvement in validation loss. 6 epochs without improvement.
[2024-04-29 03:53:29.363573] Completed epoch 36 with training loss 0.56389278, validation loss 0.68586659 No improvement in validation loss. 7 epochs without improvement.
[2024-04-29 03:56:12.787647] Completed epoch 37 with training loss 0.56472218, validation loss 0.67719704 No improvement in validation loss. 8 epochs without improvement.
[2024-04-29 03:58:56.054044] Completed epoch 38 with training loss 0.56313604, validation loss 0.68703187 No improvement in validation loss. 9 epochs without improvement.
[2024-04-29 04:01:39.656013] Completed epoch 39 with training loss 0.56198382, validation loss 0.64501512 No improvement in validation loss. 10 epochs without improvement.
[2024-04-29 04:04:23.048052] Completed epoch 40 with training loss 0.56228542, validation loss 0.62890518 No improvement in validation loss. 11 epochs without improvement.
[2024-04-29 04:07:06.171371] Completed epoch 41 with training loss 0.56116742, validation loss 0.64826739 No improvement in validation loss. 12 epochs without improvement.
[2024-04-29 04:09:49.163348] Completed epoch 42 with training loss 0.56153125, validation loss 0.64055640 No improvement in validation loss. 13 epochs without improvement.
[2024-04-29 04:12:34.210185] Completed epoch 43 with training loss 0.56031489, validation loss 0.63237381 No improvement in validation loss. 14 epochs without improvement.
[2024-04-29 04:15:20.031952] Completed epoch 44 with training loss 0.55934858, validation loss 0.68684787 No improvement in validation loss. 15 epochs without improvement.
Early stopping due to no improvement in validation loss.
Loaded best model from disk from epoch 29.
HypotensionCNN(
(abpBlock1): ResidualBlock(
(bn1): BatchNorm1d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU()
(dropout): Dropout(p=0.5, inplace=False)
(conv1): Conv1d(1, 2, kernel_size=(15,), stride=(1,), padding=(7,), bias=False)
(bn2): BatchNorm1d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv1d(2, 2, kernel_size=(15,), stride=(1,), padding=(7,), bias=False)
(residualConv): Conv1d(1, 2, kernel_size=(15,), stride=(1,), padding=(7,), bias=False)
(downsample): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(abpBlock2): ResidualBlock(
(bn1): BatchNorm1d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU()
(dropout): Dropout(p=0.5, inplace=False)
(conv1): Conv1d(2, 2, kernel_size=(15,), stride=(1,), padding=(7,), bias=False)
(bn2): BatchNorm1d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv1d(2, 2, kernel_size=(15,), stride=(1,), padding=(7,), bias=False)
(residualConv): Conv1d(2, 2, kernel_size=(15,), stride=(1,), padding=(7,), bias=False)
)
(abpBlock3): ResidualBlock(
(bn1): BatchNorm1d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU()
(dropout): Dropout(p=0.5, inplace=False)
(conv1): Conv1d(2, 2, kernel_size=(15,), stride=(1,), padding=(7,), bias=False)
(bn2): BatchNorm1d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv1d(2, 2, kernel_size=(15,), stride=(1,), padding=(7,), bias=False)
(residualConv): Conv1d(2, 2, kernel_size=(15,), stride=(1,), padding=(7,), bias=False)
(downsample): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(abpBlock4): ResidualBlock(
(bn1): BatchNorm1d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU()
(dropout): Dropout(p=0.5, inplace=False)
(conv1): Conv1d(2, 2, kernel_size=(15,), stride=(1,), padding=(7,), bias=False)
(bn2): BatchNorm1d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv1d(2, 2, kernel_size=(15,), stride=(1,), padding=(7,), bias=False)
(residualConv): Conv1d(2, 2, kernel_size=(15,), stride=(1,), padding=(7,), bias=False)
)
(abpBlock5): ResidualBlock(
(bn1): BatchNorm1d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU()
(dropout): Dropout(p=0.5, inplace=False)
(conv1): Conv1d(2, 2, kernel_size=(15,), stride=(1,), padding=(7,), bias=False)
(bn2): BatchNorm1d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv1d(2, 2, kernel_size=(15,), stride=(1,), padding=(7,), bias=False)
(residualConv): Conv1d(2, 2, kernel_size=(15,), stride=(1,), padding=(7,), bias=False)
(downsample): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(abpBlock6): ResidualBlock(
(bn1): BatchNorm1d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU()
(dropout): Dropout(p=0.5, inplace=False)
(conv1): Conv1d(2, 4, kernel_size=(15,), stride=(1,), padding=(7,), bias=False)
(bn2): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv1d(4, 4, kernel_size=(15,), stride=(1,), padding=(7,), bias=False)
(residualConv): Conv1d(2, 4, kernel_size=(15,), stride=(1,), padding=(7,), bias=False)
)
(abpBlock7): ResidualBlock(
(bn1): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU()
(dropout): Dropout(p=0.5, inplace=False)
(conv1): Conv1d(4, 4, kernel_size=(7,), stride=(1,), padding=(3,), bias=False)
(bn2): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv1d(4, 4, kernel_size=(7,), stride=(1,), padding=(3,), bias=False)
(residualConv): Conv1d(4, 4, kernel_size=(7,), stride=(1,), padding=(3,), bias=False)
(downsample): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(abpBlock8): ResidualBlock(
(bn1): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU()
(dropout): Dropout(p=0.5, inplace=False)
(conv1): Conv1d(4, 4, kernel_size=(7,), stride=(1,), padding=(3,), bias=False)
(bn2): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv1d(4, 4, kernel_size=(7,), stride=(1,), padding=(3,), bias=False)
(residualConv): Conv1d(4, 4, kernel_size=(7,), stride=(1,), padding=(3,), bias=False)
)
(abpBlock9): ResidualBlock(
(bn1): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU()
(dropout): Dropout(p=0.5, inplace=False)
(conv1): Conv1d(4, 4, kernel_size=(7,), stride=(1,), padding=(3,), bias=False)
(bn2): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv1d(4, 4, kernel_size=(7,), stride=(1,), padding=(3,), bias=False)
(residualConv): Conv1d(4, 4, kernel_size=(7,), stride=(1,), padding=(3,), bias=False)
(downsample): MaxPool1d(kernel_size=2, stride=2, padding=1, dilation=1, ceil_mode=False)
)
(abpBlock10): ResidualBlock(
(bn1): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU()
(dropout): Dropout(p=0.5, inplace=False)
(conv1): Conv1d(4, 4, kernel_size=(7,), stride=(1,), padding=(3,), bias=False)
(bn2): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv1d(4, 4, kernel_size=(7,), stride=(1,), padding=(3,), bias=False)
(residualConv): Conv1d(4, 4, kernel_size=(7,), stride=(1,), padding=(3,), bias=False)
)
(abpBlock11): ResidualBlock(
(bn1): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU()
(dropout): Dropout(p=0.5, inplace=False)
(conv1): Conv1d(4, 6, kernel_size=(7,), stride=(1,), padding=(3,), bias=False)
(bn2): BatchNorm1d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv1d(6, 6, kernel_size=(7,), stride=(1,), padding=(3,), bias=False)
(residualConv): Conv1d(4, 6, kernel_size=(7,), stride=(1,), padding=(3,), bias=False)
(downsample): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(abpBlock12): ResidualBlock(
(bn1): BatchNorm1d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU()
(dropout): Dropout(p=0.5, inplace=False)
(conv1): Conv1d(6, 6, kernel_size=(7,), stride=(1,), padding=(3,), bias=False)
(bn2): BatchNorm1d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv1d(6, 6, kernel_size=(7,), stride=(1,), padding=(3,), bias=False)
(residualConv): Conv1d(6, 6, kernel_size=(7,), stride=(1,), padding=(3,), bias=False)
)
(abpFc): Linear(in_features=2814, out_features=32, bias=True)
(fullLinear1): Linear(in_features=32, out_features=16, bias=True)
(fullLinear2): Linear(in_features=16, out_features=1, bias=True)
(sigmoid): Sigmoid()
)
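As a quick sanity check on the architecture printed above, the weight count of each bias-free `Conv1d` layer can be read off as in_channels × out_channels × kernel_size. A back-of-the-envelope sketch (`conv1d_params` is an illustrative helper, not part of the notebook), applied to the three convolutions of `abpBlock1`:

```python
def conv1d_params(in_ch, out_ch, kernel, bias=False):
    """Weight count of a Conv1d: one kernel-length filter per
    (in, out) channel pair, plus one bias per output channel if enabled."""
    return in_ch * out_ch * kernel + (out_ch if bias else 0)

# The three convolutions of abpBlock1 in the printout (all bias=False):
block1 = [
    conv1d_params(1, 2, 15),  # conv1: Conv1d(1, 2, kernel_size=(15,)) -> 30
    conv1d_params(2, 2, 15),  # conv2: Conv1d(2, 2, kernel_size=(15,)) -> 60
    conv1d_params(1, 2, 15),  # residualConv: Conv1d(1, 2, kernel_size=(15,)) -> 30
]
print(sum(block1))  # 120 convolutional weights (BatchNorm affine parameters not counted)
```

The small channel counts (2 to 6) explain why the model trains in minutes per epoch even on long waveform inputs.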
Plot the training and validation losses after each epoch:
# Create x-axis values for epochs
epochs = range(len(train_losses))
plt.figure(figsize=(16, 9))
# Plot the training and validation losses
plt.plot(epochs, train_losses, 'b', label='Training Loss')
plt.plot(epochs, val_losses, 'r', label='Validation Loss')
# Add a vertical bar at the best_epoch
plt.axvline(x=best_epoch, color='g', linestyle='--', label='Best Epoch')
# Shade everything to the right of the best_epoch a light red
plt.axvspan(best_epoch, max(epochs), facecolor='r', alpha=0.1)
# Add labels and title
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training and Validation Losses')
# Add legend
plt.legend(loc='upper right')
# Show the plot
plt.show()
def eval_model(model, dataloader):
    model.eval()
    model = model.to(device)
    total_loss = 0
    all_predictions = []
    all_labels = []
    with torch.no_grad():
        for abp, ecg, eeg, label in tqdm(dataloader):
            batch = len(abp)
            abp = torch.nan_to_num(abp.reshape(batch, 1, -1)).type(torch.FloatTensor).to(device)
            ecg = torch.nan_to_num(ecg.reshape(batch, 1, -1)).type(torch.FloatTensor).to(device)
            eeg = torch.nan_to_num(eeg.reshape(batch, 1, -1)).type(torch.FloatTensor).to(device)
            label = label.type(torch.float).reshape(batch, 1).to(device)
            pred = model(abp, eeg, ecg)
            loss = loss_func(pred, label)
            total_loss += loss.item()
            all_predictions.append(pred.detach().cpu().numpy())
            all_labels.append(label.detach().cpu().numpy())
    # Flatten the lists
    all_predictions = np.concatenate(all_predictions).flatten()
    all_labels = np.concatenate(all_labels).flatten()
    # Calculate AUROC and AUPRC (arguments are y_true, y_score)
    auroc = roc_auc_score(all_labels, all_predictions)
    precision, recall, _ = precision_recall_curve(all_labels, all_predictions)
    auprc = auc(recall, precision)
    # Determine the optimal threshold, which is argmin(abs(sensitivity - specificity)) per the paper
    thresholds = np.linspace(0, 1, 101)  # 0 to 1 in 0.01 steps
    min_diff = float('inf')
    optimal_sensitivity = None
    optimal_specificity = None
    optimal_threshold = None
    for threshold in thresholds:
        all_predictions_binary = (all_predictions > threshold).astype(int)
        tn, fp, fn, tp = confusion_matrix(all_labels, all_predictions_binary).ravel()
        sensitivity = tp / (tp + fn)
        specificity = tn / (tn + fp)
        diff = abs(sensitivity - specificity)
        if diff < min_diff:
            min_diff = diff
            optimal_threshold = threshold
            optimal_sensitivity = sensitivity
            optimal_specificity = specificity
    avg_loss = total_loss / len(dataloader)
    return all_predictions, all_labels, avg_loss, auroc, auprc, optimal_sensitivity, optimal_specificity, optimal_threshold
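The optimal-threshold rule in `eval_model` above (minimize |sensitivity − specificity|, following the paper) can be illustrated on plain Python lists. A minimal sketch that assumes both classes are present (`balanced_threshold` is a hypothetical helper, not the notebook's sklearn-based code):

```python
def balanced_threshold(labels, scores, steps=101):
    """Sweep cut-offs from 0 to 1 and return the (threshold,
    sensitivity, specificity) that minimizes |sensitivity - specificity|.
    Assumes labels contain at least one positive and one negative."""
    best = (float("inf"), None, None, None)
    for i in range(steps):
        t = i / (steps - 1)
        tp = sum(1 for y, s in zip(labels, scores) if y == 1 and s > t)
        fn = sum(1 for y, s in zip(labels, scores) if y == 1 and s <= t)
        tn = sum(1 for y, s in zip(labels, scores) if y == 0 and s <= t)
        fp = sum(1 for y, s in zip(labels, scores) if y == 0 and s > t)
        sens = tp / (tp + fn)
        spec = tn / (tn + fp)
        diff = abs(sens - spec)
        if diff < best[0]:
            best = (diff, t, sens, spec)
    return best[1], best[2], best[3]

# Perfectly separable toy scores: the first cut-off balancing both rates wins.
print(balanced_threshold([0, 0, 1, 1], [0.2, 0.3, 0.7, 0.8]))  # (0.3, 1.0, 1.0)
```

This balanced operating point is why the reported sensitivity and specificity below come out nearly equal (both around 0.67 to 0.69).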
# validation loop
valid_predictions, valid_labels, valid_loss, valid_auroc, valid_auprc, valid_sensitivity, valid_specificity, valid_threshold = eval_model(model, val_loader)
# test loop
test_predictions, test_labels, test_loss, test_auroc, test_auprc, test_sensitivity, test_specificity, test_threshold = eval_model(model, test_loader)
print(f'Best Epoch: {best_epoch}')
print()
print(f"Validation predictions: {valid_predictions}")
print(f"Validation labels: {valid_labels}")
print(f"Validation loss: {valid_loss}")
print(f"Validation AUROC: {valid_auroc}")
print(f"Validation AUPRC: {valid_auprc}")
print(f"Validation Sensitivity: {valid_sensitivity}")
print(f"Validation Specificity: {valid_specificity}")
print(f"Validation Threshold: {valid_threshold}")
print()
print(f"Test predictions: {test_predictions}")
print(f"Test labels: {test_labels}")
print(f"Test loss: {test_loss}")
print(f"Test AUROC: {test_auroc}")
print(f"Test AUPRC: {test_auprc}")
print(f"Test Sensitivity: {test_sensitivity}")
print(f"Test Specificity: {test_specificity}")
print(f"Test Threshold: {test_threshold}")
100%|███████████████████████████████████████████████████████████████████████████████████████| 95/95 [00:17<00:00, 5.40it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████| 299/299 [00:59<00:00, 4.99it/s]
Best Epoch: 29

Validation predictions: [0.20395379 0.47758454 0.08360817 ... 0.7062303 0.65294236 0.5635181]
Validation labels: [0. 1. 1. ... 1. 1. 1.]
Validation loss: 0.6123215374193693
Validation AUROC: 0.7350563857408707
Validation AUPRC: 0.652829437966435
Validation Sensitivity: 0.6706114398422091
Validation Specificity: 0.6827683615819209
Validation Threshold: 0.36

Test predictions: [0.4427195 0.49085984 0.35140878 ... 0.13020563 0.18610786 0.49174148]
Test labels: [0. 0. 1. ... 0. 0. 0.]
Test loss: 0.5801073525063569
Test AUROC: 0.7404521245699772
Test AUPRC: 0.6229591168454549
Test Sensitivity: 0.6794944618006248
Test Specificity: 0.6857569721115537
Test Threshold: 0.36
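The AUROC and AUPRC figures reported above can be reproduced from any saved prediction/label arrays with scikit-learn. A minimal sketch with hypothetical values standing in for `valid_predictions` and `valid_labels`:

```python
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score

# Hypothetical probabilities and binary labels, for illustration only
labels = np.array([0., 1., 1., 0., 1., 0.])
preds = np.array([0.2, 0.48, 0.08, 0.1, 0.71, 0.65])

auroc = roc_auc_score(labels, preds)              # area under the ROC curve
auprc = average_precision_score(labels, preds)    # area under the precision-recall curve

print(auroc, auprc)
```

Both metrics are threshold-free, which is why the notebook computes them once per evaluation and searches for an operating threshold separately.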
PRINT_DETAILED = False

val_aurocs = []
val_auprcs = []
test_aurocs = []
test_auprcs = []

# Evaluate every saved checkpoint on the validation and test sets
for all_mod in all_models:
    model.load_state_dict(torch.load(all_mod))
    model.train(False)  # switch to evaluation mode

    # validation loop
    valid_predictions, valid_labels, valid_loss, valid_auroc, valid_auprc, valid_sensitivity, valid_specificity, valid_threshold = eval_model(model, val_loader)
    val_aurocs.append(valid_auroc)
    val_auprcs.append(valid_auprc)

    # test loop
    test_predictions, test_labels, test_loss, test_auroc, test_auprc, test_sensitivity, test_specificity, test_threshold = eval_model(model, test_loader)
    test_aurocs.append(test_auroc)
    test_auprcs.append(test_auprc)

    print(f'Model: {all_mod}')
    # Full prediction/label arrays are verbose, so they are gated behind PRINT_DETAILED
    if PRINT_DETAILED:
        print(f"Validation predictions: {valid_predictions}")
        print(f"Validation labels: {valid_labels}")
    print(f"Validation loss: {valid_loss}")
    print(f"Validation AUROC: {valid_auroc}")
    print(f"Validation AUPRC: {valid_auprc}")
    print(f"Validation Sensitivity: {valid_sensitivity}")
    print(f"Validation Specificity: {valid_specificity}")
    print(f"Validation Threshold: {valid_threshold}")
    print()
    if PRINT_DETAILED:
        print(f"Test predictions: {test_predictions}")
        print(f"Test labels: {test_labels}")
    print(f"Test loss: {test_loss}")
    print(f"Test AUROC: {test_auroc}")
    print(f"Test AUPRC: {test_auprc}")
    print(f"Test Sensitivity: {test_sensitivity}")
    print(f"Test Specificity: {test_specificity}")
    print(f"Test Threshold: {test_threshold}")
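The save/restore pattern behind this sweep can be demonstrated on a toy module. The sketch below is self-contained and uses a hypothetical `nn.Linear` stand-in for the notebook's network and a temporary file in place of the `./vitaldb_cache/models/` paths:

```python
import os
import tempfile

import torch
import torch.nn as nn

# Toy model standing in for the notebook's network (hypothetical)
model = nn.Linear(4, 1)
path = os.path.join(tempfile.gettempdir(), 'toy_checkpoint.model')
torch.save(model.state_dict(), path)  # persist weights only, not the module object

# Re-create the architecture, restore the weights, and freeze for evaluation
restored = nn.Linear(4, 1)
restored.load_state_dict(torch.load(path))
restored.eval()  # equivalent to restored.train(False)
```

Saving only the `state_dict` (rather than the whole module) is what lets the loop above reuse a single `model` instance and swap weights in from each checkpoint file.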
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0000.model Validation loss: 0.6387810267900166 Validation AUROC: 0.7400399491859727 Validation AUPRC: 0.6680950287017987 Validation Sensitivity: 0.6777120315581854 Validation Specificity: 0.6867231638418079 Validation Threshold: 0.31 Test loss: 0.5978416105775929 Test AUROC: 0.7463930960810738 Test AUPRC: 0.647821728567491 Test Sensitivity: 0.6962510650383413 Test Specificity: 0.6758798140770252 Test Threshold: 0.3
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0001.model Validation loss: 0.6443360002417313 Validation AUROC: 0.7403860640301319 Validation AUPRC: 0.6685993295731341 Validation Sensitivity: 0.6804733727810651 Validation Specificity: 0.6858757062146893 Validation Threshold: 0.29 Test loss: 0.5989955266981221 Test AUROC: 0.7467083290430063 Test AUPRC: 0.6484069544041045 Test Sensitivity: 0.6944049985799489 Test Specificity: 0.6782038512616202 Test Threshold: 0.28
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0002.model Validation loss: 0.6560433400304694 Validation AUROC: 0.7409460769565073 Validation AUPRC: 0.6687586092767344 Validation Sensitivity: 0.6800788954635109 Validation Specificity: 0.6830508474576271 Validation Threshold: 0.27 Test loss: 0.6064765780765476 Test AUROC: 0.747364637115648 Test AUPRC: 0.6489103176035561 Test Sensitivity: 0.6973871059358137 Test Specificity: 0.6752158034528553 Test Threshold: 0.26
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0003.model Validation loss: 0.6175579874139083 Validation AUROC: 0.7406613623953912 Validation AUPRC: 0.6676324365220729 Validation Sensitivity: 0.6796844181459566 Validation Specificity: 0.6850282485875706 Validation Threshold: 0.35000000000000003 Test loss: 0.5832182517817188 Test AUROC: 0.747331304574375 Test AUPRC: 0.647617180003468 Test Sensitivity: 0.6972451008236297 Test Specificity: 0.6767098273572377 Test Threshold: 0.34
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0004.model Validation loss: 0.6325045353487918 Validation AUROC: 0.7406738430336866 Validation AUPRC: 0.6673158150656753 Validation Sensitivity: 0.683629191321499 Validation Specificity: 0.6824858757062147 Validation Threshold: 0.31 Test loss: 0.5903931063254143 Test AUROC: 0.7472115938404859 Test AUPRC: 0.6476387963000785 Test Sensitivity: 0.6769383697813122 Test Specificity: 0.6982901726427623 Test Threshold: 0.31
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0005.model Validation loss: 0.6255835683722245 Validation AUROC: 0.7413703072242838 Validation AUPRC: 0.6654205137018345 Validation Sensitivity: 0.6706114398422091 Validation Specificity: 0.6926553672316385 Validation Threshold: 0.33 Test loss: 0.586878407410156 Test AUROC: 0.7485687162360687 Test AUPRC: 0.647233078036291 Test Sensitivity: 0.6951150241408691 Test Specificity: 0.6811918990703851 Test Threshold: 0.32
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0006.model Validation loss: 0.618489597659362 Validation AUROC: 0.7414619619117663 Validation AUPRC: 0.6655441951556209 Validation Sensitivity: 0.6804733727810651 Validation Specificity: 0.6878531073446328 Validation Threshold: 0.36 Test loss: 0.5859681487482128 Test AUROC: 0.7469606097054553 Test AUPRC: 0.6460891910068622 Test Sensitivity: 0.6755183186594718 Test Specificity: 0.6971281540504648 Test Threshold: 0.36
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0007.model Validation loss: 0.6318302487072192 Validation AUROC: 0.7416876163095198 Validation AUPRC: 0.6652683386940792 Validation Sensitivity: 0.6785009861932939 Validation Specificity: 0.6901129943502825 Validation Threshold: 0.31 Test loss: 0.5890876116062885 Test AUROC: 0.7479047822248825 Test AUPRC: 0.6467581742207538 Test Sensitivity: 0.6952570292530531 Test Specificity: 0.6799468791500664 Test Threshold: 0.3
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0008.model Validation loss: 0.6523008409299349 Validation AUROC: 0.7421522972174863 Validation AUPRC: 0.6680415661625254 Validation Sensitivity: 0.696646942800789 Validation Specificity: 0.6720338983050848 Validation Threshold: 0.27 Test loss: 0.604122545109146 Test AUROC: 0.7457633373539827 Test AUPRC: 0.6457187246457163 Test Sensitivity: 0.6914228912240841 Test Specificity: 0.6798638778220452 Test Threshold: 0.27
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0009.model Validation loss: 0.6813889679155851 Validation AUROC: 0.742232752760784 Validation AUPRC: 0.6652875171915722 Validation Sensitivity: 0.672189349112426 Validation Specificity: 0.6940677966101695 Validation Threshold: 0.24 Test loss: 0.6243627961961721 Test AUROC: 0.7471383400413305 Test AUPRC: 0.6470996655350998 Test Sensitivity: 0.6942629934677649 Test Specificity: 0.6784528552456839 Test Threshold: 0.23
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0010.model Validation loss: 0.6229492284749684 Validation AUROC: 0.7424219681520855 Validation AUPRC: 0.6648559140978972 Validation Sensitivity: 0.6974358974358974 Validation Specificity: 0.6771186440677966 Validation Threshold: 0.32 Test loss: 0.5838596717669413 Test AUROC: 0.7485962320838769 Test AUPRC: 0.6453448234248486 Test Sensitivity: 0.690996875887532 Test Specificity: 0.685175962815405 Test Threshold: 0.32
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0011.model Validation loss: 0.615843253386648 Validation AUROC: 0.743164398979262 Validation AUPRC: 0.6652488557445398 Validation Sensitivity: 0.6887573964497041 Validation Specificity: 0.6889830508474576 Validation Threshold: 0.37 Test loss: 0.5858555816886417 Test AUROC: 0.7469759499821409 Test AUPRC: 0.6407476032306707 Test Sensitivity: 0.6813405282590174 Test Specificity: 0.6931440903054449 Test Threshold: 0.37
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0012.model Validation loss: 0.6794318280721966 Validation AUROC: 0.7446976231070104 Validation AUPRC: 0.6689777605957021 Validation Sensitivity: 0.6903353057199211 Validation Specificity: 0.6838983050847458 Validation Threshold: 0.24 Test loss: 0.6236616017228385 Test AUROC: 0.7471523071776135 Test AUPRC: 0.6436558892166175 Test Sensitivity: 0.6868787276341949 Test Specificity: 0.689160026560425 Test Threshold: 0.24
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0013.model Validation loss: 0.6628241112357691 Validation AUROC: 0.7452099978827489 Validation AUPRC: 0.6694381664148772 Validation Sensitivity: 0.6927021696252466 Validation Specificity: 0.6861581920903955 Validation Threshold: 0.25 Test loss: 0.6102634491928445 Test AUROC: 0.7485396563419708 Test AUPRC: 0.6445269054437672 Test Sensitivity: 0.687304742970747 Test Specificity: 0.6886620185922975 Test Threshold: 0.25
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0014.model Validation loss: 0.6220772391871402 Validation AUROC: 0.7427968887551677 Validation AUPRC: 0.6643071695754124 Validation Sensitivity: 0.6796844181459566 Validation Specificity: 0.6946327683615819 Validation Threshold: 0.33 Test loss: 0.5835881430070137 Test AUROC: 0.7465738261099311 Test AUPRC: 0.6410318063862939 Test Sensitivity: 0.6972451008236297 Test Specificity: 0.6776228419654714 Test Threshold: 0.32
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0015.model Validation loss: 0.6675321958566967 Validation AUROC: 0.7449018821248287 Validation AUPRC: 0.666619496490892 Validation Sensitivity: 0.6816568047337278 Validation Specificity: 0.6915254237288135 Validation Threshold: 0.25 Test loss: 0.6137017674968395 Test AUROC: 0.7481947211155378 Test AUPRC: 0.6415150239138332 Test Sensitivity: 0.68389662027833 Test Specificity: 0.6921480743691899 Test Threshold: 0.25
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0016.model Validation loss: 0.6139778049368607 Validation AUROC: 0.7423028449169258 Validation AUPRC: 0.6630974839911995 Validation Sensitivity: 0.6824457593688363 Validation Specificity: 0.68954802259887 Validation Threshold: 0.37 Test loss: 0.5833295102302845 Test AUROC: 0.7467699258820064 Test AUPRC: 0.6339391185656855 Test Sensitivity: 0.6861687020732746 Test Specificity: 0.6895750332005313 Test Threshold: 0.37
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0017.model Validation loss: 0.6180980236906755 Validation AUROC: 0.742333210755636 Validation AUPRC: 0.6655740917558959 Validation Sensitivity: 0.6930966469428008 Validation Specificity: 0.6754237288135593 Validation Threshold: 0.33 Test loss: 0.581086583859147 Test AUROC: 0.7463296958903003 Test AUPRC: 0.6374448946695843 Test Sensitivity: 0.6962510650383413 Test Specificity: 0.6791998671978752 Test Threshold: 0.33
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0018.model Validation loss: 0.6140767812728882 Validation AUROC: 0.7413853508508006 Validation AUPRC: 0.6612301449573781 Validation Sensitivity: 0.6824457593688363 Validation Specificity: 0.6884180790960452 Validation Threshold: 0.36 Test loss: 0.5817086570039641 Test AUROC: 0.7460766608846259 Test AUPRC: 0.6335320295820634 Test Sensitivity: 0.6816245384833854 Test Specificity: 0.6922310756972112 Test Threshold: 0.36
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0019.model Validation loss: 0.6214054054335544 Validation AUROC: 0.7434634885612721 Validation AUPRC: 0.6624512477726926 Validation Sensitivity: 0.6891518737672584 Validation Specificity: 0.6822033898305084 Validation Threshold: 0.32 Test loss: 0.5820028808304298 Test AUROC: 0.7484099387737322 Test AUPRC: 0.6355571539734302 Test Sensitivity: 0.6908548707753479 Test Specificity: 0.6857569721115537 Test Threshold: 0.32
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0020.model Validation loss: 0.6188408553600311 Validation AUROC: 0.7408133587403469 Validation AUPRC: 0.6585518671807284 Validation Sensitivity: 0.685207100591716 Validation Specificity: 0.6824858757062147 Validation Threshold: 0.34 Test loss: 0.5820679765281869 Test AUROC: 0.7451261766339922 Test AUPRC: 0.6279967811300393 Test Sensitivity: 0.6891508094291394 Test Specificity: 0.6855909694555112 Test Threshold: 0.34
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0021.model Validation loss: 0.6202563963438336 Validation AUROC: 0.7424114933306588 Validation AUPRC: 0.6616424972447899 Validation Sensitivity: 0.6749506903353057 Validation Specificity: 0.6926553672316385 Validation Threshold: 0.33 Test loss: 0.5807039358444437 Test AUROC: 0.7469517933190084 Test AUPRC: 0.6338937701042213 Test Sensitivity: 0.6980971314967339 Test Specificity: 0.6783698539176627 Test Threshold: 0.32
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0022.model Validation loss: 0.6431774635063975 Validation AUROC: 0.7424755123190587 Validation AUPRC: 0.661262869147858 Validation Sensitivity: 0.6820512820512821 Validation Specificity: 0.6878531073446328 Validation Threshold: 0.28 Test loss: 0.5952648379930285 Test AUROC: 0.7472981311523763 Test AUPRC: 0.6334205465597789 Test Sensitivity: 0.6865947174098268 Test Specificity: 0.6903220451527224 Test Threshold: 0.28
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0023.model Validation loss: 0.6402282062329744 Validation AUROC: 0.7416899564292002 Validation AUPRC: 0.658668843749461 Validation Sensitivity: 0.693491124260355 Validation Specificity: 0.6711864406779661 Validation Threshold: 0.28 Test loss: 0.5929512529568529 Test AUROC: 0.7468907681307337 Test AUPRC: 0.6310238989730499 Test Sensitivity: 0.6772223800056802 Test Specificity: 0.6979581673306773 Test Threshold: 0.29
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0024.model Validation loss: 0.6349092988591445 Validation AUROC: 0.741624711663825 Validation AUPRC: 0.6606746417424187 Validation Sensitivity: 0.6840236686390533 Validation Specificity: 0.6819209039548022 Validation Threshold: 0.29 Test loss: 0.5890709742853872 Test AUROC: 0.7471263412694013 Test AUPRC: 0.6325462074403769 Test Sensitivity: 0.6897188298778756 Test Specificity: 0.6845949535192563 Test Threshold: 0.29
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0025.model Validation loss: 0.642772879098591 Validation AUROC: 0.737960195678579 Validation AUPRC: 0.6534923519875904 Validation Sensitivity: 0.6824457593688363 Validation Specificity: 0.676271186440678 Validation Threshold: 0.29 Test loss: 0.5974743700147074 Test AUROC: 0.7417247251644752 Test AUPRC: 0.6233899566645328 Test Sensitivity: 0.6880147685316671 Test Specificity: 0.6796978751660027 Test Threshold: 0.29
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0026.model Validation loss: 0.6780585169792175 Validation AUROC: 0.7420618683069792 Validation AUPRC: 0.6608936915359056 Validation Sensitivity: 0.683629191321499 Validation Specificity: 0.6822033898305084 Validation Threshold: 0.23 Test loss: 0.6201831335108415 Test AUROC: 0.7466249917965174 Test AUPRC: 0.6329448254267589 Test Sensitivity: 0.6902868503266117 Test Specificity: 0.6866699867197875 Test Threshold: 0.23
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0027.model Validation loss: 0.613245994166324 Validation AUROC: 0.7329155662532455 Validation AUPRC: 0.6451171916280873 Validation Sensitivity: 0.6875739644970414 Validation Specificity: 0.6694915254237288 Validation Threshold: 0.37 Test loss: 0.5832619988998043 Test AUROC: 0.7396088867100942 Test AUPRC: 0.6148930652086293 Test Sensitivity: 0.6898608349900597 Test Specificity: 0.6737217795484728 Test Threshold: 0.37
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0028.model Validation loss: 0.6610989307102404 Validation AUROC: 0.7362598201450874 Validation AUPRC: 0.6489924315201367 Validation Sensitivity: 0.6863905325443787 Validation Specificity: 0.6714689265536723 Validation Threshold: 0.26 Test loss: 0.6090877181710208 Test AUROC: 0.7428379589659917 Test AUPRC: 0.6221493934116767 Test Sensitivity: 0.6907128656631639 Test Specificity: 0.6787018592297477 Test Threshold: 0.26
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0029.model Validation loss: 0.6124660962506344 Validation AUROC: 0.7350563857408707 Validation AUPRC: 0.652829437966435 Validation Sensitivity: 0.6706114398422091 Validation Specificity: 0.6827683615819209 Validation Threshold: 0.36 Test loss: 0.5801073525063569 Test AUROC: 0.7404521245699772 Test AUPRC: 0.6229591168454549 Test Sensitivity: 0.6794944618006248 Test Specificity: 0.6857569721115537 Test Threshold: 0.36
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0030.model Validation loss: 0.63962740741278 Validation AUROC: 0.7378741684217565 Validation AUPRC: 0.6556341073811118 Validation Sensitivity: 0.6761341222879684 Validation Specificity: 0.6830508474576271 Validation Threshold: 0.28 Test loss: 0.5900731322956723 Test AUROC: 0.7466976032252699 Test AUPRC: 0.6274312508572893 Test Sensitivity: 0.6848906560636183 Test Specificity: 0.6889110225763613 Test Threshold: 0.28
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0031.model Validation loss: 0.659438492436158 Validation AUROC: 0.7381125263263464 Validation AUPRC: 0.6572125712820158 Validation Sensitivity: 0.6824457593688363 Validation Specificity: 0.6717514124293785 Validation Threshold: 0.25 Test loss: 0.6057033517886963 Test AUROC: 0.743899991721083 Test AUPRC: 0.6272753853847262 Test Sensitivity: 0.6912808861119001 Test Specificity: 0.6780378486055777 Test Threshold: 0.25
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0032.model Validation loss: 0.6318849322042968 Validation AUROC: 0.7311915666544089 Validation AUPRC: 0.648058472432239 Validation Sensitivity: 0.6729783037475345 Validation Specificity: 0.669774011299435 Validation Threshold: 0.31 Test loss: 0.5905645691032793 Test AUROC: 0.7376491973033739 Test AUPRC: 0.615506542095931 Test Sensitivity: 0.679920477137177 Test Specificity: 0.6783698539176627 Test Threshold: 0.31
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0033.model Validation loss: 0.6488435707594219 Validation AUROC: 0.7338580773130969 Validation AUPRC: 0.6480515000096597 Validation Sensitivity: 0.6848126232741617 Validation Specificity: 0.6627118644067796 Validation Threshold: 0.27 Test loss: 0.5998478742148167 Test AUROC: 0.7412353332103754 Test AUPRC: 0.6194080637962579 Test Sensitivity: 0.6938369781312127 Test Specificity: 0.671812749003984 Test Threshold: 0.27
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0034.model Validation loss: 0.6389461636543274 Validation AUROC: 0.7318215045855203 Validation AUPRC: 0.6480751694163766 Validation Sensitivity: 0.6824457593688363 Validation Specificity: 0.6627118644067796 Validation Threshold: 0.29 Test loss: 0.5943676391772204 Test AUROC: 0.7390753656773078 Test AUPRC: 0.6167380358644224 Test Sensitivity: 0.6907128656631639 Test Specificity: 0.671480743691899 Test Threshold: 0.29
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0035.model Validation loss: 0.6346045327814002 Validation AUROC: 0.7312213196046312 Validation AUPRC: 0.6475339512325334 Validation Sensitivity: 0.6694280078895464 Validation Specificity: 0.6731638418079096 Validation Threshold: 0.3 Test loss: 0.5890515674416437 Test AUROC: 0.7393132902452484 Test AUPRC: 0.6168524512077759 Test Sensitivity: 0.6807725078102812 Test Specificity: 0.6800298804780877 Test Threshold: 0.3
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0036.model Validation loss: 0.6858358022413755 Validation AUROC: 0.7342748414847502 Validation AUPRC: 0.6508414235905711 Validation Sensitivity: 0.6812623274161735 Validation Specificity: 0.669774011299435 Validation Threshold: 0.22 Test loss: 0.6251822330780252 Test AUROC: 0.7412007866479741 Test AUPRC: 0.6221875671594421 Test Sensitivity: 0.6878727634194831 Test Specificity: 0.6761288180610889 Test Threshold: 0.22
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0037.model Validation loss: 0.6772078978387933 Validation AUROC: 0.7302482755546641 Validation AUPRC: 0.6461552202724268 Validation Sensitivity: 0.665483234714004 Validation Specificity: 0.6757062146892655 Validation Threshold: 0.24 Test loss: 0.6201854038697022 Test AUROC: 0.736949255189787 Test AUPRC: 0.6153724890408299 Test Sensitivity: 0.6770803748934962 Test Specificity: 0.6842629482071713 Test Threshold: 0.24
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0038.model Validation loss: 0.687013596610019 Validation AUROC: 0.7295055661418113 Validation AUPRC: 0.6436243860793894 Validation Sensitivity: 0.6611439842209073 Validation Specificity: 0.6819209039548022 Validation Threshold: 0.23 Test loss: 0.6259211399981808 Test AUROC: 0.7380724133193628 Test AUPRC: 0.6152906829613143 Test Sensitivity: 0.6718261857426867 Test Specificity: 0.689824037184595 Test Threshold: 0.23
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0039.model Validation loss: 0.6451818936749508 Validation AUROC: 0.732067217151963 Validation AUPRC: 0.6491256207800374 Validation Sensitivity: 0.6694280078895464 Validation Specificity: 0.6805084745762712 Validation Threshold: 0.28 Test loss: 0.5959248230309789 Test AUROC: 0.7392944964909839 Test AUPRC: 0.6194771748514758 Test Sensitivity: 0.6794944618006248 Test Specificity: 0.6846779548472776 Test Threshold: 0.28
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0040.model Validation loss: 0.6290142667920966 Validation AUROC: 0.7258809436254026 Validation AUPRC: 0.6430361913762791 Validation Sensitivity: 0.6650887573964497 Validation Specificity: 0.6799435028248587 Validation Threshold: 0.32 Test loss: 0.5884272557835913 Test AUROC: 0.7326805034901576 Test AUPRC: 0.6112355301964911 Test Sensitivity: 0.6722522010792389 Test Specificity: 0.6839309428950863 Test Threshold: 0.32
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0041.model Validation loss: 0.6481288282494796 Validation AUROC: 0.7215054769943949 Validation AUPRC: 0.6402835733086201 Validation Sensitivity: 0.6717948717948717 Validation Specificity: 0.6624293785310734 Validation Threshold: 0.28 Test loss: 0.6026340150912869 Test AUROC: 0.7264968220462842 Test AUPRC: 0.6055441639334893 Test Sensitivity: 0.6766543595569441 Test Specificity: 0.6673306772908366 Test Threshold: 0.28
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0042.model Validation loss: 0.6405608804602372 Validation AUROC: 0.7280121240486299 Validation AUPRC: 0.6449519225936438 Validation Sensitivity: 0.6627218934911243 Validation Specificity: 0.6802259887005649 Validation Threshold: 0.29 Test loss: 0.5936258033565853 Test AUROC: 0.7354920173042565 Test AUPRC: 0.6138849803678386 Test Sensitivity: 0.6752343084351037 Test Specificity: 0.6850929614873837 Test Threshold: 0.29
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0043.model Validation loss: 0.6324820116946571 Validation AUROC: 0.724718461315593 Validation AUPRC: 0.6406671987200219 Validation Sensitivity: 0.663905325443787 Validation Specificity: 0.6728813559322034 Validation Threshold: 0.31 Test loss: 0.5894579508631524 Test AUROC: 0.7320272894034013 Test AUPRC: 0.6099001558825158 Test Sensitivity: 0.6766543595569441 Test Specificity: 0.6770418326693227 Test Threshold: 0.31
Model: ./vitaldb_cache/models/ABP_ALL_MINS_0044.model Validation loss: 0.6869757495428387 Validation AUROC: 0.7201951214076376 Validation AUPRC: 0.635576515376461 Validation Sensitivity: 0.67534516765286 Validation Specificity: 0.6598870056497175 Validation Threshold: 0.23 Test loss: 0.6280749833902787 Test AUROC: 0.7273277546823027 Test AUPRC: 0.6026586503188287 Test Sensitivity: 0.679920477137177 Test Specificity: 0.662765604249668 Test Threshold: 0.23
# Create x-axis values for epochs
epochs = range(0, len(val_aurocs))
# Find model with highest AUROC
np_test_aurocs = np.array(test_aurocs)
test_auroc_idx = np.argmax(np_test_aurocs)
print(f'Epoch with best Validation Loss: {best_epoch:3}, {val_losses[best_epoch]:.4}')
print(f'Epoch with best model Test AUROC: {test_auroc_idx:3}, {np.max(np_test_aurocs):.4}')
print(f'Best Model on Validation Loss: {all_models[best_epoch]}')
print(f'Best Model on Test AUROC: {all_models[test_auroc_idx]}')
plt.figure(figsize=(16, 9))
# Plot the validation and test AUROC and AUPRC across training
plt.plot(epochs, val_aurocs, 'C0', label='AUROC - Validation')
plt.plot(epochs, test_aurocs, 'C1', label='AUROC - Test')
plt.plot(epochs, val_auprcs, 'C2', label='AUPRC - Validation')
plt.plot(epochs, test_auprcs, 'C3', label='AUPRC - Test')
# Add a vertical bar at the best_epoch
plt.axvline(x=best_epoch, color='g', linestyle='--', label='Best Epoch - Validation Loss')
plt.axvline(x=test_auroc_idx, color='maroon', linestyle='--', label='Best Epoch - Test AUROC')
# Shade everything to the right of the best-test-AUROC epoch a light red
plt.axvspan(test_auroc_idx, max(epochs), facecolor='r', alpha=0.1)
# Add labels and title
plt.xlabel('Epochs')
plt.ylabel('AUROC / AUPRC')
plt.title('Validation and Test AUROC by Model Iteration Across Training')
# Add legend
plt.legend(loc='right')
# Show the plot
plt.show()
Epoch with best Validation Loss:  29, 0.6125
Epoch with best model Test AUROC: 10, 0.7486
Best Model on Validation Loss: ./vitaldb_cache/models/ABP_ALL_MINS_0029.model
Best Model on Test AUROC: ./vitaldb_cache/models/ABP_ALL_MINS_0010.model
best_model_val_loss = all_models[best_epoch]
print(f'Best Model Based on Validation Loss: {best_model_val_loss}')
model.load_state_dict(torch.load(best_model_val_loss))
model.train(False)
(best_model_val_test_predictions, best_model_val_test_labels, test_loss,
test_auroc, best_model_val_test_auprc, test_sensitivity, test_specificity, best_model_val_test_threshold) \
= eval_model(model, test_loader)
Best Model Based on Validation Loss: ./vitaldb_cache/models/ABP_ALL_MINS_0029.model
# ROC curve for the best-validation-loss model: true labels vs. predicted probabilities
display = RocCurveDisplay.from_predictions(
best_model_val_test_labels,
best_model_val_test_predictions,
plot_chance_level=True
)
plt.show()
roc_auc_score(best_model_val_test_labels, best_model_val_test_predictions)
0.7404521245699772
best_model_val_test_predictions_binary = \
(best_model_val_test_predictions > best_model_val_test_threshold).astype(int)
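For reference, here is a minimal NumPy sketch of how sensitivity and specificity can be recovered from the continuous predictions once a decision threshold is fixed; the helper name and the strict-greater-than thresholding convention are our own assumptions, not necessarily what `eval_model` does internally:

```python
import numpy as np

def sensitivity_specificity(y_true, y_score, threshold):
    """Sensitivity and specificity at a fixed decision threshold.

    Illustrative helper (not from the project code): scores strictly above
    the threshold are treated as positive predictions.
    """
    y_true = np.asarray(y_true).astype(int)
    y_pred = (np.asarray(y_score) > threshold).astype(int)
    tp = np.sum((y_pred == 1) & (y_true == 1))
    tn = np.sum((y_pred == 0) & (y_true == 0))
    fp = np.sum((y_pred == 1) & (y_true == 0))
    fn = np.sum((y_pred == 0) & (y_true == 1))
    sensitivity = tp / (tp + fn)  # true positive rate (recall)
    specificity = tn / (tn + fp)  # true negative rate
    return sensitivity, specificity
```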
# PR curve from the thresholded (binary) predictions; binarizing first collapses the curve to a single operating point
display = PrecisionRecallDisplay.from_predictions(
best_model_val_test_labels,
best_model_val_test_predictions_binary,
plot_chance_level=True
)
plt.show()
# PR curve from the continuous predicted probabilities
display = PrecisionRecallDisplay.from_predictions(
best_model_val_test_labels,
best_model_val_test_predictions,
plot_chance_level=True
)
plt.show()
best_model_val_test_auprc
0.6229591168454549
best_model_auroc = all_models[test_auroc_idx]
print(f'Best Model Based on Model AUROC: {best_model_auroc}')
model.load_state_dict(torch.load(best_model_auroc))
model.train(False)
(best_model_auroc_test_predictions, best_model_auroc_test_labels, test_loss,
test_auroc, best_model_auroc_test_auprc, test_sensitivity, test_specificity, best_model_auroc_test_threshold) \
= eval_model(model, test_loader)
Best Model Based on Model AUROC: ./vitaldb_cache/models/ABP_ALL_MINS_0010.model
# ROC curve for the best-test-AUROC model: true labels vs. predicted probabilities
display = RocCurveDisplay.from_predictions(
best_model_auroc_test_labels,
best_model_auroc_test_predictions,
plot_chance_level=True
)
plt.show()
roc_auc_score(best_model_auroc_test_labels, best_model_auroc_test_predictions)
0.7485962320838769
best_model_auroc_test_predictions_binary = \
(best_model_auroc_test_predictions > best_model_auroc_test_threshold).astype(int)
# PR curve from the thresholded (binary) predictions; binarizing first collapses the curve to a single operating point
display = PrecisionRecallDisplay.from_predictions(
best_model_auroc_test_labels,
best_model_auroc_test_predictions_binary,
plot_chance_level=True
)
plt.show()
# PR curve from the continuous predicted probabilities
display = PrecisionRecallDisplay.from_predictions(
best_model_auroc_test_labels,
best_model_auroc_test_predictions,
plot_chance_level=True
)
plt.show()
best_model_auroc_test_auprc
0.6453448234248486
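To make the two checkpoint-selection strategies easier to compare, here is a small summary sketch; the numbers are rounded copies of the test-set outputs above. Note that selecting a checkpoint by test AUROC peeks at the test set and is therefore optimistic, so only the validation-loss selection is methodologically sound for reporting:

```python
# Side-by-side summary of the two checkpoint-selection strategies, using the
# test-set values reported above (rounded to four decimal places).
results = {
    'selected on validation loss (epoch 29)': {'AUROC': 0.7405, 'AUPRC': 0.6230},
    'selected on test AUROC (epoch 10)': {'AUROC': 0.7486, 'AUPRC': 0.6453},
}
for name, metrics in results.items():
    print(f"{name}: AUROC={metrics['AUROC']:.4f}  AUPRC={metrics['AUPRC']:.4f}")
```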
Once our experiments are complete, we will build tables comparing a set of measures across every experiment performed. The full set of experiments and measures is listed below.
Note: each experiment will be repeated with the following time-to-IOH-event durations:
Note: the above list of experiments will be performed only if there is sufficient time and GPU capacity to complete it before the submission deadline. Should we face constraints on this front, we will reduce our experimental coverage to the following 4 core experiments, which are the minimum necessary to evaluate the hypotheses stated at the head of this report:
For additional details please review the "Planned Actions" in the Discussion section of this report.
[ TODO for final report - collect data for all measures listed above. ]
[ TODO for final report - generate ROC and PRC plots for each experiment ]
We are collecting a broad set of measures for each experiment so that we can comprehensively compare against the corresponding experiments in the original paper. Our key experimental results, however, will focus on the subset of these measures that addresses the main experiments defined at the beginning of this notebook.
The key experimental result measures will be as follows:
The following table is Table 3 from the original paper which presents the measured values for each signal combination across each of the four temporal predictive categories:
We have not yet completed the experiments needed to measure our reproduced model's performance, so we cannot yet determine whether our results accurately reproduce those of the original paper. These details are expected to be included in the final report.
As of the draft submission, the reported evaluation measures of our model are too good to be true (all measures are 1.0). We suspect data leakage in the dataset splitting process and will address this in time for the final report.
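One common source of such leakage is splitting at the sample level, so that overlapping windows from the same surgical case land in both the training and test sets. A minimal NumPy sketch of a case-level split that rules this out; the helper name, split fractions, and seed here are our own illustration, not taken from the paper or our pipeline:

```python
import numpy as np

def split_by_case(case_ids, train_frac=0.6, val_frac=0.2, seed=0):
    """Assign sample indices to train/val/test so that every sample from a
    given surgical case lands in exactly one split (no within-case leakage).

    Illustrative helper: fractions apply to cases, not samples.
    """
    rng = np.random.default_rng(seed)
    cases = np.unique(case_ids)
    rng.shuffle(cases)
    n_train = int(len(cases) * train_frac)
    n_val = int(len(cases) * val_frac)
    train_cases = set(cases[:n_train])
    val_cases = set(cases[n_train:n_train + n_val])
    splits = {'train': [], 'val': [], 'test': []}
    for i, cid in enumerate(case_ids):
        if cid in train_cases:
            splits['train'].append(i)
        elif cid in val_cases:
            splits['val'].append(i)
        else:
            splits['test'].append(i)
    return splits
```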
Our assessment is that this paper will be reproducible. The outstanding risk is runtime: each experiment can take up to 7 hours on hardware available to the team (i.e., roughly 7 h for ~70 epochs on a desktop with an AMD Ryzen 7 3800X 8-core CPU, an RTX 2070 SUPER GPU, and 32 GB RAM). There are 28 experiments in total (7 combinations of signal inputs times 4 time horizons per combination). Should our team find it impossible to complete all of the experiments represented in Table 3 of our selected paper, we will reduce the set to those directly related to the hypotheses described at the beginning of this notebook (i.e., 4 combinations of interest: ABP alone, ABP+EEG, ABP+ECG, ABP+ECG+EEG), for a new total of 16 experiments.
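The experiment counts above can be sanity-checked by enumerating the grid. The signal-combination lists come from the text; of the horizon values, only the 3-minute one is stated here, so the other three are placeholders:

```python
from itertools import product

# Full grid: all 7 non-empty combinations of the three signals x 4 horizons.
full_combos = ['ABP', 'EEG', 'ECG', 'ABP+EEG', 'ABP+ECG', 'ECG+EEG', 'ABP+ECG+EEG']
# Reduced grid: the 4 combinations tied directly to our hypotheses.
core_combos = ['ABP', 'ABP+EEG', 'ABP+ECG', 'ABP+ECG+EEG']
horizons_min = [3, 5, 10, 15]  # only the 3-minute horizon is stated in the text

full_grid = list(product(full_combos, horizons_min))
core_grid = list(product(core_combos, horizons_min))
print(len(full_grid), len(core_grid))  # 28 16
```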
Our proposal included a collection of potential ablations to be investigated:
Given the amount of time required to conduct each experiment, our team intends to choose only a small number of ablations from this set. Further, we intend to perform ablation analysis only against the best-performing signal combination and time horizon from the reproduction experiments. In other words, we intend to perform ablation analysis against the following training combinations, and only against models trained with data measured 3 minutes prior to an IOH event:
Time and GPU resource permitting, we will complete a broader range of experiments. For additional details, please see the section below titled "Plans for next phase".
Our team will address how our experimental results align with the results published in the paper in the final submission of this report. The time available while preparing the draft notebook was not sufficient to complete model training and result analysis for a large number of experiments.
The most difficult aspect of preparing this draft was the data preprocessing.
The most notable suggestion would be to correct the hyperparameters published in Supplemental Table 1. Specifically, the output size listed for residual blocks 11 and 12 for the ECG and ABP data sets is 496x6. This is a typo and should read 469x6. The typo became apparent when executing the size-down operation within Residual Block 11, where the tensor dimensions were visibly misaligned.
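A dimension typo like this is easy to catch with the standard 1-D convolution/pooling output-length formula (PyTorch convention). The kernel, stride, and input length below are hypothetical stand-ins, not the paper's actual hyperparameters, chosen only to show how a clean stride-2 size-down step produces 469 rather than 496:

```python
def conv1d_out_len(l_in, kernel, stride, padding=0, dilation=1):
    """Output length of a 1-D convolution/pooling layer (PyTorch convention)."""
    return (l_in + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

# Hypothetical example: a kernel-2, stride-2 size-down step exactly halves
# the temporal dimension, so an input of length 938 yields 469.
print(conv1d_out_len(938, kernel=2, stride=2))
```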
Additionally, more explicit references to the signal quality index assessment tools should be added. Our team could not find a link to the MATLAB source code described in reference [3], and had to manually locate the GitHub profile of the corresponding author's lab for reference [3] in order to find MATLAB source implementing the metrics described therein.
Our team plans to accomplish the following goals in service of preparing the Final Report: